."),
+ ET.XML("And finally, here is an embedded XHTML fragment.
"),
+ )
+ )
+ )
+
+ print ET.tostring(page)
+
+ Here's a prettyprinted version of the output from the above script::
+
+
+
+ This is a sample document
+
+
+ Hello!
+ This is a paragraph with bold text in it!
+ This is another paragraph, with link .
+ Here are some reserved characters: <spam&egg>.
+ And finally, here is an embedded XHTML fragment.
+
+
+
+ For namespace support, you can pass a namespace map (``nsmap``)
+ and/or a specific target ``namespace`` to the ElementMaker class::
+
+ >>> E = ElementMaker(namespace="http://my.ns/")
+ >>> print(ET.tostring( E.test ))
+
+
+ >>> E = ElementMaker(namespace="http://my.ns/", nsmap={'p':'http://my.ns/'})
+ >>> print(ET.tostring( E.test ))
+
+ """
+
+ def __init__(self, typemap=None,
+ namespace=None, nsmap=None, makeelement=None):
+ self._namespace = '{' + namespace + '}' if namespace is not None else None
+ self._nsmap = dict(nsmap) if nsmap else None
+
+ assert makeelement is None or callable(makeelement)
+ self._makeelement = makeelement if makeelement is not None else ET.Element
+
+ # initialize the default type map functions for this element factory
+ typemap = dict(typemap) if typemap else {}
+
+ def add_text(elem, item):
+ try:
+ last_child = elem[-1]
+ except IndexError:
+ elem.text = (elem.text or "") + item
+ else:
+ last_child.tail = (last_child.tail or "") + item
+
+ def add_cdata(elem, cdata):
+ if elem.text:
+ raise ValueError("Can't add a CDATA section. Element already has some text: %r" % elem.text)
+ elem.text = cdata
+
+ if str not in typemap:
+ typemap[str] = add_text
+ if unicode not in typemap:
+ typemap[unicode] = add_text
+ if ET.CDATA not in typemap:
+ typemap[ET.CDATA] = add_cdata
+
+ def add_dict(elem, item):
+ attrib = elem.attrib
+ for k, v in item.items():
+ if isinstance(v, basestring):
+ attrib[k] = v
+ else:
+ attrib[k] = typemap[type(v)](None, v)
+
+ if dict not in typemap:
+ typemap[dict] = add_dict
+
+ self._typemap = typemap
+
+ def __call__(self, tag, *children, **attrib):
+ typemap = self._typemap
+
+ # We'll usually get a 'str', and the compiled type check is very fast.
+ if not isinstance(tag, str) and isinstance(tag, _QName):
+ # A QName is explicitly qualified, do not look at self._namespace.
+ tag = tag.text
+ elif self._namespace is not None and tag[0] != '{':
+ tag = self._namespace + tag
+ elem = self._makeelement(tag, nsmap=self._nsmap)
+ if attrib:
+ typemap[dict](elem, attrib)
+
+ for item in children:
+ if callable(item):
+ item = item()
+ t = typemap.get(type(item))
+ if t is None:
+ if ET.iselement(item):
+ elem.append(item)
+ continue
+ for basetype in type(item).__mro__:
+ # See if the typemap knows of any of this type's bases.
+ t = typemap.get(basetype)
+ if t is not None:
+ break
+ else:
+ raise TypeError("bad argument type: %s(%r)" %
+ (type(item).__name__, item))
+ v = t(elem, item)
+ if v:
+ typemap.get(type(v))(elem, v)
+
+ return elem
+
+ def __getattr__(self, tag):
+ return partial(self, tag)
+
+ # Allow subscripting ElementMaker in type annotions (PEP 560)
+ def __class_getitem__(cls, item):
+ return _GenericAlias(cls, item)
+
+
+# create factory object
+E = ElementMaker()
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/classlookup.pxi b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/classlookup.pxi
new file mode 100644
index 0000000000000000000000000000000000000000..92d1d47a58657a7741d20f48cfe3525a66dbc722
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/classlookup.pxi
@@ -0,0 +1,580 @@
+# Configurable Element class lookup
+
+################################################################################
+# Custom Element classes
+
+cdef public class ElementBase(_Element) [ type LxmlElementBaseType,
+ object LxmlElementBase ]:
+ """ElementBase(*children, attrib=None, nsmap=None, **_extra)
+
+ The public Element class. All custom Element classes must inherit
+ from this one. To create an Element, use the `Element()` factory.
+
+ BIG FAT WARNING: Subclasses *must not* override __init__ or
+ __new__ as it is absolutely undefined when these objects will be
+ created or destroyed. All persistent state of Elements must be
+ stored in the underlying XML. If you really need to initialize
+ the object after creation, you can implement an ``_init(self)``
+ method that will be called directly after object creation.
+
+ Subclasses of this class can be instantiated to create a new
+ Element. By default, the tag name will be the class name and the
+ namespace will be empty. You can modify this with the following
+ class attributes:
+
+ * TAG - the tag name, possibly containing a namespace in Clark
+ notation
+
+ * NAMESPACE - the default namespace URI, unless provided as part
+ of the TAG attribute.
+
+ * HTML - flag if the class is an HTML tag, as opposed to an XML
+ tag. This only applies to un-namespaced tags and defaults to
+ false (i.e. XML).
+
+ * PARSER - the parser that provides the configuration for the
+ newly created document. Providing an HTML parser here will
+ default to creating an HTML element.
+
+ In user code, the latter three are commonly inherited in class
+ hierarchies that implement a common namespace.
+ """
+ def __init__(self, *children, attrib=None, nsmap=None, **_extra):
+ """ElementBase(*children, attrib=None, nsmap=None, **_extra)
+ """
+ cdef bint is_html = 0
+ cdef _BaseParser parser
+ cdef _Element last_child
+ # don't use normal attribute access as it might be overridden
+ _getattr = object.__getattribute__
+ try:
+ namespace = _utf8(_getattr(self, 'NAMESPACE'))
+ except AttributeError:
+ namespace = None
+ try:
+ ns, tag = _getNsTag(_getattr(self, 'TAG'))
+ if ns is not None:
+ namespace = ns
+ except AttributeError:
+ tag = _utf8(_getattr(_getattr(self, '__class__'), '__name__'))
+ if b'.' in tag:
+ tag = tag.split(b'.')[-1]
+ try:
+ parser = _getattr(self, 'PARSER')
+ except AttributeError:
+ parser = None
+ for child in children:
+ if isinstance(child, _Element):
+ parser = (<_Element>child)._doc._parser
+ break
+ if isinstance(parser, HTMLParser):
+ is_html = 1
+ if namespace is None:
+ try:
+ is_html = _getattr(self, 'HTML')
+ except AttributeError:
+ pass
+ _initNewElement(self, is_html, tag, namespace, parser,
+ attrib, nsmap, _extra)
+ last_child = None
+ for child in children:
+ if _isString(child):
+ if last_child is None:
+ _setNodeText(self._c_node,
+ (_collectText(self._c_node.children) or '') + child)
+ else:
+ _setTailText(last_child._c_node,
+ (_collectText(last_child._c_node.next) or '') + child)
+ elif isinstance(child, _Element):
+ last_child = child
+ _appendChild(self, last_child)
+ elif isinstance(child, type) and issubclass(child, ElementBase):
+ last_child = child()
+ _appendChild(self, last_child)
+ else:
+ raise TypeError, f"Invalid child type: {type(child)!r}"
+
+cdef class CommentBase(_Comment):
+ """All custom Comment classes must inherit from this one.
+
+ To create an XML Comment instance, use the ``Comment()`` factory.
+
+ Subclasses *must not* override __init__ or __new__ as it is
+ absolutely undefined when these objects will be created or
+ destroyed. All persistent state of Comments must be stored in the
+ underlying XML. If you really need to initialize the object after
+ creation, you can implement an ``_init(self)`` method that will be
+ called after object creation.
+ """
+ def __init__(self, text):
+ # copied from Comment() factory
+ cdef _Document doc
+ cdef xmlDoc* c_doc
+ if text is None:
+ text = b''
+ else:
+ text = _utf8(text)
+ c_doc = _newXMLDoc()
+ doc = _documentFactory(c_doc, None)
+ self._c_node = _createComment(c_doc, _xcstr(text))
+ if self._c_node is NULL:
+ raise MemoryError()
+ tree.xmlAddChild(c_doc, self._c_node)
+ _registerProxy(self, doc, self._c_node)
+ self._init()
+
+cdef class PIBase(_ProcessingInstruction):
+ """All custom Processing Instruction classes must inherit from this one.
+
+ To create an XML ProcessingInstruction instance, use the ``PI()``
+ factory.
+
+ Subclasses *must not* override __init__ or __new__ as it is
+ absolutely undefined when these objects will be created or
+ destroyed. All persistent state of PIs must be stored in the
+ underlying XML. If you really need to initialize the object after
+ creation, you can implement an ``_init(self)`` method that will be
+ called after object creation.
+ """
+ def __init__(self, target, text=None):
+ # copied from PI() factory
+ cdef _Document doc
+ cdef xmlDoc* c_doc
+ target = _utf8(target)
+ if text is None:
+ text = b''
+ else:
+ text = _utf8(text)
+ c_doc = _newXMLDoc()
+ doc = _documentFactory(c_doc, None)
+ self._c_node = _createPI(c_doc, _xcstr(target), _xcstr(text))
+ if self._c_node is NULL:
+ raise MemoryError()
+ tree.xmlAddChild(c_doc, self._c_node)
+ _registerProxy(self, doc, self._c_node)
+ self._init()
+
+cdef class EntityBase(_Entity):
+ """All custom Entity classes must inherit from this one.
+
+ To create an XML Entity instance, use the ``Entity()`` factory.
+
+ Subclasses *must not* override __init__ or __new__ as it is
+ absolutely undefined when these objects will be created or
+ destroyed. All persistent state of Entities must be stored in the
+ underlying XML. If you really need to initialize the object after
+ creation, you can implement an ``_init(self)`` method that will be
+ called after object creation.
+ """
+ def __init__(self, name):
+ cdef _Document doc
+ cdef xmlDoc* c_doc
+ name_utf = _utf8(name)
+ c_name = _xcstr(name_utf)
+ if c_name[0] == c'#':
+ if not _characterReferenceIsValid(c_name + 1):
+ raise ValueError, f"Invalid character reference: '{name}'"
+ elif not _xmlNameIsValid(c_name):
+ raise ValueError, f"Invalid entity reference: '{name}'"
+ c_doc = _newXMLDoc()
+ doc = _documentFactory(c_doc, None)
+ self._c_node = _createEntity(c_doc, c_name)
+ if self._c_node is NULL:
+ raise MemoryError()
+ tree.xmlAddChild(c_doc, self._c_node)
+ _registerProxy(self, doc, self._c_node)
+ self._init()
+
+
+cdef int _validateNodeClass(xmlNode* c_node, cls) except -1:
+ if c_node.type == tree.XML_ELEMENT_NODE:
+ expected = ElementBase
+ elif c_node.type == tree.XML_COMMENT_NODE:
+ expected = CommentBase
+ elif c_node.type == tree.XML_ENTITY_REF_NODE:
+ expected = EntityBase
+ elif c_node.type == tree.XML_PI_NODE:
+ expected = PIBase
+ else:
+ assert False, f"Unknown node type: {c_node.type}"
+
+ if not (isinstance(cls, type) and issubclass(cls, expected)):
+ raise TypeError(
+ f"result of class lookup must be subclass of {type(expected)}, got {type(cls)}")
+ return 0
+
+
+################################################################################
+# Element class lookup
+
+ctypedef public object (*_element_class_lookup_function)(object, _Document, xmlNode*)
+
+# class to store element class lookup functions
+cdef public class ElementClassLookup [ type LxmlElementClassLookupType,
+ object LxmlElementClassLookup ]:
+ """ElementClassLookup(self)
+ Superclass of Element class lookups.
+ """
+ cdef _element_class_lookup_function _lookup_function
+
+
+cdef public class FallbackElementClassLookup(ElementClassLookup) \
+ [ type LxmlFallbackElementClassLookupType,
+ object LxmlFallbackElementClassLookup ]:
+ """FallbackElementClassLookup(self, fallback=None)
+
+ Superclass of Element class lookups with additional fallback.
+ """
+ cdef readonly ElementClassLookup fallback
+ cdef _element_class_lookup_function _fallback_function
+ def __cinit__(self):
+ # fall back to default lookup
+ self._fallback_function = _lookupDefaultElementClass
+
+ def __init__(self, ElementClassLookup fallback=None):
+ if fallback is not None:
+ self._setFallback(fallback)
+ else:
+ self._fallback_function = _lookupDefaultElementClass
+
+ cdef void _setFallback(self, ElementClassLookup lookup):
+ """Sets the fallback scheme for this lookup method.
+ """
+ self.fallback = lookup
+ self._fallback_function = lookup._lookup_function
+ if self._fallback_function is NULL:
+ self._fallback_function = _lookupDefaultElementClass
+
+ def set_fallback(self, ElementClassLookup lookup not None):
+ """set_fallback(self, lookup)
+
+ Sets the fallback scheme for this lookup method.
+ """
+ self._setFallback(lookup)
+
+cdef inline object _callLookupFallback(FallbackElementClassLookup lookup,
+ _Document doc, xmlNode* c_node):
+ return lookup._fallback_function(lookup.fallback, doc, c_node)
+
+
+################################################################################
+# default lookup scheme
+
+cdef class ElementDefaultClassLookup(ElementClassLookup):
+ """ElementDefaultClassLookup(self, element=None, comment=None, pi=None, entity=None)
+ Element class lookup scheme that always returns the default Element
+ class.
+
+ The keyword arguments ``element``, ``comment``, ``pi`` and ``entity``
+ accept the respective Element classes.
+ """
+ cdef readonly object element_class
+ cdef readonly object comment_class
+ cdef readonly object pi_class
+ cdef readonly object entity_class
+ def __cinit__(self):
+ self._lookup_function = _lookupDefaultElementClass
+
+ def __init__(self, element=None, comment=None, pi=None, entity=None):
+ if element is None:
+ self.element_class = _Element
+ elif issubclass(element, ElementBase):
+ self.element_class = element
+ else:
+ raise TypeError, "element class must be subclass of ElementBase"
+
+ if comment is None:
+ self.comment_class = _Comment
+ elif issubclass(comment, CommentBase):
+ self.comment_class = comment
+ else:
+ raise TypeError, "comment class must be subclass of CommentBase"
+
+ if entity is None:
+ self.entity_class = _Entity
+ elif issubclass(entity, EntityBase):
+ self.entity_class = entity
+ else:
+ raise TypeError, "Entity class must be subclass of EntityBase"
+
+ if pi is None:
+ self.pi_class = None # special case, see below
+ elif issubclass(pi, PIBase):
+ self.pi_class = pi
+ else:
+ raise TypeError, "PI class must be subclass of PIBase"
+
+cdef object _lookupDefaultElementClass(state, _Document _doc, xmlNode* c_node):
+ "Trivial class lookup function that always returns the default class."
+ if c_node.type == tree.XML_ELEMENT_NODE:
+ if state is not None:
+ return (state).element_class
+ else:
+ return _Element
+ elif c_node.type == tree.XML_COMMENT_NODE:
+ if state is not None:
+ return (state).comment_class
+ else:
+ return _Comment
+ elif c_node.type == tree.XML_ENTITY_REF_NODE:
+ if state is not None:
+ return (state).entity_class
+ else:
+ return _Entity
+ elif c_node.type == tree.XML_PI_NODE:
+ if state is None or (state).pi_class is None:
+ # special case XSLT-PI
+ if c_node.name is not NULL and c_node.content is not NULL:
+ if tree.xmlStrcmp(c_node.name, "xml-stylesheet") == 0:
+ if tree.xmlStrstr(c_node.content, "text/xsl") is not NULL or \
+ tree.xmlStrstr(c_node.content, "text/xml") is not NULL:
+ return _XSLTProcessingInstruction
+ return _ProcessingInstruction
+ else:
+ return (state).pi_class
+ else:
+ assert False, f"Unknown node type: {c_node.type}"
+
+
+################################################################################
+# attribute based lookup scheme
+
+cdef class AttributeBasedElementClassLookup(FallbackElementClassLookup):
+ """AttributeBasedElementClassLookup(self, attribute_name, class_mapping, fallback=None)
+ Checks an attribute of an Element and looks up the value in a
+ class dictionary.
+
+ Arguments:
+ - attribute name - '{ns}name' style string
+ - class mapping - Python dict mapping attribute values to Element classes
+ - fallback - optional fallback lookup mechanism
+
+ A None key in the class mapping will be checked if the attribute is
+ missing.
+ """
+ cdef object _class_mapping
+ cdef tuple _pytag
+ cdef const_xmlChar* _c_ns
+ cdef const_xmlChar* _c_name
+ def __cinit__(self):
+ self._lookup_function = _attribute_class_lookup
+
+ def __init__(self, attribute_name, class_mapping,
+ ElementClassLookup fallback=None):
+ self._pytag = _getNsTag(attribute_name)
+ ns, name = self._pytag
+ if ns is None:
+ self._c_ns = NULL
+ else:
+ self._c_ns = _xcstr(ns)
+ self._c_name = _xcstr(name)
+ self._class_mapping = dict(class_mapping)
+
+ FallbackElementClassLookup.__init__(self, fallback)
+
+cdef object _attribute_class_lookup(state, _Document doc, xmlNode* c_node):
+ cdef AttributeBasedElementClassLookup lookup
+ cdef python.PyObject* dict_result
+
+ lookup = state
+ if c_node.type == tree.XML_ELEMENT_NODE:
+ value = _attributeValueFromNsName(
+ c_node, lookup._c_ns, lookup._c_name)
+ dict_result = python.PyDict_GetItem(lookup._class_mapping, value)
+ if dict_result is not NULL:
+ cls = dict_result
+ _validateNodeClass(c_node, cls)
+ return cls
+ return _callLookupFallback(lookup, doc, c_node)
+
+
+################################################################################
+# per-parser lookup scheme
+
+cdef class ParserBasedElementClassLookup(FallbackElementClassLookup):
+ """ParserBasedElementClassLookup(self, fallback=None)
+ Element class lookup based on the XML parser.
+ """
+ def __cinit__(self):
+ self._lookup_function = _parser_class_lookup
+
+cdef object _parser_class_lookup(state, _Document doc, xmlNode* c_node):
+ if doc._parser._class_lookup is not None:
+ return doc._parser._class_lookup._lookup_function(
+ doc._parser._class_lookup, doc, c_node)
+ return _callLookupFallback(state, doc, c_node)
+
+
+################################################################################
+# custom class lookup based on node type, namespace, name
+
+cdef class CustomElementClassLookup(FallbackElementClassLookup):
+ """CustomElementClassLookup(self, fallback=None)
+ Element class lookup based on a subclass method.
+
+ You can inherit from this class and override the method::
+
+ lookup(self, type, doc, namespace, name)
+
+ to lookup the element class for a node. Arguments of the method:
+ * type: one of 'element', 'comment', 'PI', 'entity'
+ * doc: document that the node is in
+ * namespace: namespace URI of the node (or None for comments/PIs/entities)
+ * name: name of the element/entity, None for comments, target for PIs
+
+ If you return None from this method, the fallback will be called.
+ """
+ def __cinit__(self):
+ self._lookup_function = _custom_class_lookup
+
+ def lookup(self, type, doc, namespace, name):
+ "lookup(self, type, doc, namespace, name)"
+ return None
+
+cdef object _custom_class_lookup(state, _Document doc, xmlNode* c_node):
+ cdef CustomElementClassLookup lookup
+
+ lookup = state
+
+ if c_node.type == tree.XML_ELEMENT_NODE:
+ element_type = "element"
+ elif c_node.type == tree.XML_COMMENT_NODE:
+ element_type = "comment"
+ elif c_node.type == tree.XML_PI_NODE:
+ element_type = "PI"
+ elif c_node.type == tree.XML_ENTITY_REF_NODE:
+ element_type = "entity"
+ else:
+ element_type = "element"
+ if c_node.name is NULL:
+ name = None
+ else:
+ name = funicode(c_node.name)
+ c_str = tree._getNs(c_node)
+ ns = funicode(c_str) if c_str is not NULL else None
+
+ cls = lookup.lookup(element_type, doc, ns, name)
+ if cls is not None:
+ _validateNodeClass(c_node, cls)
+ return cls
+ return _callLookupFallback(lookup, doc, c_node)
+
+
+################################################################################
+# read-only tree based class lookup
+
+cdef class PythonElementClassLookup(FallbackElementClassLookup):
+ """PythonElementClassLookup(self, fallback=None)
+ Element class lookup based on a subclass method.
+
+ This class lookup scheme allows access to the entire XML tree in
+ read-only mode. To use it, re-implement the ``lookup(self, doc,
+ root)`` method in a subclass::
+
+ from lxml import etree, pyclasslookup
+
+ class MyElementClass(etree.ElementBase):
+ honkey = True
+
+ class MyLookup(pyclasslookup.PythonElementClassLookup):
+ def lookup(self, doc, root):
+ if root.tag == "sometag":
+ return MyElementClass
+ else:
+ for child in root:
+ if child.tag == "someothertag":
+ return MyElementClass
+ # delegate to default
+ return None
+
+ If you return None from this method, the fallback will be called.
+
+ The first argument is the opaque document instance that contains
+ the Element. The second argument is a lightweight Element proxy
+ implementation that is only valid during the lookup. Do not try
+ to keep a reference to it. Once the lookup is done, the proxy
+ will be invalid.
+
+ Also, you cannot wrap such a read-only Element in an ElementTree,
+ and you must take care not to keep a reference to them outside of
+ the `lookup()` method.
+
+ Note that the API of the Element objects is not complete. It is
+ purely read-only and does not support all features of the normal
+ `lxml.etree` API (such as XPath, extended slicing or some
+ iteration methods).
+
+ See https://lxml.de/element_classes.html
+ """
+ def __cinit__(self):
+ self._lookup_function = _python_class_lookup
+
+ def lookup(self, doc, element):
+ """lookup(self, doc, element)
+
+ Override this method to implement your own lookup scheme.
+ """
+ return None
+
+cdef object _python_class_lookup(state, _Document doc, tree.xmlNode* c_node):
+ cdef PythonElementClassLookup lookup
+ cdef _ReadOnlyProxy proxy
+ lookup = state
+
+ proxy = _newReadOnlyProxy(None, c_node)
+ cls = lookup.lookup(doc, proxy)
+ _freeReadOnlyProxies(proxy)
+
+ if cls is not None:
+ _validateNodeClass(c_node, cls)
+ return cls
+ return _callLookupFallback(lookup, doc, c_node)
+
+################################################################################
+# Global setup
+
+cdef _element_class_lookup_function LOOKUP_ELEMENT_CLASS
+cdef object ELEMENT_CLASS_LOOKUP_STATE
+
+cdef void _setElementClassLookupFunction(
+ _element_class_lookup_function function, object state):
+ global LOOKUP_ELEMENT_CLASS, ELEMENT_CLASS_LOOKUP_STATE
+ if function is NULL:
+ state = DEFAULT_ELEMENT_CLASS_LOOKUP
+ function = DEFAULT_ELEMENT_CLASS_LOOKUP._lookup_function
+
+ ELEMENT_CLASS_LOOKUP_STATE = state
+ LOOKUP_ELEMENT_CLASS = function
+
+def set_element_class_lookup(ElementClassLookup lookup = None):
+ """set_element_class_lookup(lookup = None)
+
+ Set the global element class lookup method.
+
+ This defines the main entry point for looking up element implementations.
+ The standard implementation uses the :class:`ParserBasedElementClassLookup`
+ to delegate to different lookup schemes for each parser.
+
+ .. warning::
+
+ This should only be changed by applications, not by library packages.
+ In most cases, parser specific lookups should be preferred,
+ which can be configured via
+ :meth:`~lxml.etree.XMLParser.set_element_class_lookup`
+ (and the same for HTML parsers).
+
+ Globally replacing the element class lookup by something other than a
+ :class:`ParserBasedElementClassLookup` will prevent parser specific lookup
+ schemes from working. Several tools rely on parser specific lookups,
+ including :mod:`lxml.html` and :mod:`lxml.objectify`.
+ """
+ if lookup is None or lookup._lookup_function is NULL:
+ _setElementClassLookupFunction(NULL, None)
+ else:
+ _setElementClassLookupFunction(lookup._lookup_function, lookup)
+
+# default setup: parser delegation
+cdef ParserBasedElementClassLookup DEFAULT_ELEMENT_CLASS_LOOKUP
+DEFAULT_ELEMENT_CLASS_LOOKUP = ParserBasedElementClassLookup()
+
+set_element_class_lookup(DEFAULT_ELEMENT_CLASS_LOOKUP)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/cleanup.pxi b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/cleanup.pxi
new file mode 100644
index 0000000000000000000000000000000000000000..8e266b33f0f3aef34f3448276abfb2cb8b1e4772
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/cleanup.pxi
@@ -0,0 +1,215 @@
+# functions for tree cleanup and removing elements from subtrees
+
+def cleanup_namespaces(tree_or_element, top_nsmap=None, keep_ns_prefixes=None):
+ """cleanup_namespaces(tree_or_element, top_nsmap=None, keep_ns_prefixes=None)
+
+ Remove all namespace declarations from a subtree that are not used
+ by any of the elements or attributes in that tree.
+
+ If a 'top_nsmap' is provided, it must be a mapping from prefixes
+ to namespace URIs. These namespaces will be declared on the top
+ element of the subtree before running the cleanup, which allows
+ moving namespace declarations to the top of the tree.
+
+ If a 'keep_ns_prefixes' is provided, it must be a list of prefixes.
+ These prefixes will not be removed as part of the cleanup.
+ """
+ element = _rootNodeOrRaise(tree_or_element)
+ c_element = element._c_node
+
+ if top_nsmap:
+ doc = element._doc
+ # declare namespaces from nsmap, then apply them to the subtree
+ _setNodeNamespaces(c_element, doc, None, top_nsmap)
+ moveNodeToDocument(doc, c_element.doc, c_element)
+
+ keep_ns_prefixes = (
+ set([_utf8(prefix) for prefix in keep_ns_prefixes])
+ if keep_ns_prefixes else None)
+
+ _removeUnusedNamespaceDeclarations(c_element, keep_ns_prefixes)
+
+
+def strip_attributes(tree_or_element, *attribute_names):
+ """strip_attributes(tree_or_element, *attribute_names)
+
+ Delete all attributes with the provided attribute names from an
+ Element (or ElementTree) and its descendants.
+
+ Attribute names can contain wildcards as in `_Element.iter`.
+
+ Example usage::
+
+ strip_attributes(root_element,
+ 'simpleattr',
+ '{http://some/ns}attrname',
+ '{http://other/ns}*')
+ """
+ cdef _MultiTagMatcher matcher
+ element = _rootNodeOrRaise(tree_or_element)
+ if not attribute_names:
+ return
+
+ matcher = _MultiTagMatcher.__new__(_MultiTagMatcher, attribute_names)
+ matcher.cacheTags(element._doc)
+ if matcher.rejectsAllAttributes():
+ return
+ _strip_attributes(element._c_node, matcher)
+
+
+cdef _strip_attributes(xmlNode* c_node, _MultiTagMatcher matcher):
+ cdef xmlAttr* c_attr
+ cdef xmlAttr* c_next_attr
+ tree.BEGIN_FOR_EACH_ELEMENT_FROM(c_node, c_node, 1)
+ if c_node.type == tree.XML_ELEMENT_NODE:
+ c_attr = c_node.properties
+ while c_attr is not NULL:
+ c_next_attr = c_attr.next
+ if matcher.matchesAttribute(c_attr):
+ tree.xmlRemoveProp(c_attr)
+ c_attr = c_next_attr
+ tree.END_FOR_EACH_ELEMENT_FROM(c_node)
+
+
+def strip_elements(tree_or_element, *tag_names, bint with_tail=True):
+ """strip_elements(tree_or_element, *tag_names, with_tail=True)
+
+ Delete all elements with the provided tag names from a tree or
+ subtree. This will remove the elements and their entire subtree,
+ including all their attributes, text content and descendants. It
+ will also remove the tail text of the element unless you
+ explicitly set the ``with_tail`` keyword argument option to False.
+
+ Tag names can contain wildcards as in `_Element.iter`.
+
+ Note that this will not delete the element (or ElementTree root
+ element) that you passed even if it matches. It will only treat
+ its descendants. If you want to include the root element, check
+ its tag name directly before even calling this function.
+
+ Example usage::
+
+ strip_elements(some_element,
+ 'simpletagname', # non-namespaced tag
+ '{http://some/ns}tagname', # namespaced tag
+ '{http://some/other/ns}*' # any tag from a namespace
+ lxml.etree.Comment # comments
+ )
+ """
+ cdef _MultiTagMatcher matcher
+ doc = _documentOrRaise(tree_or_element)
+ element = _rootNodeOrRaise(tree_or_element)
+ if not tag_names:
+ return
+
+ matcher = _MultiTagMatcher.__new__(_MultiTagMatcher, tag_names)
+ matcher.cacheTags(doc)
+ if matcher.rejectsAll():
+ return
+
+ if isinstance(tree_or_element, _ElementTree):
+ # include PIs and comments next to the root node
+ if matcher.matchesType(tree.XML_COMMENT_NODE):
+ _removeSiblings(element._c_node, tree.XML_COMMENT_NODE, with_tail)
+ if matcher.matchesType(tree.XML_PI_NODE):
+ _removeSiblings(element._c_node, tree.XML_PI_NODE, with_tail)
+ _strip_elements(doc, element._c_node, matcher, with_tail)
+
+cdef _strip_elements(_Document doc, xmlNode* c_node, _MultiTagMatcher matcher,
+ bint with_tail):
+ cdef xmlNode* c_child
+ cdef xmlNode* c_next
+
+ tree.BEGIN_FOR_EACH_ELEMENT_FROM(c_node, c_node, 1)
+ if c_node.type == tree.XML_ELEMENT_NODE:
+ # we run through the children here to prevent any problems
+ # with the tree iteration which would occur if we unlinked the
+ # c_node itself
+ c_child = _findChildForwards(c_node, 0)
+ while c_child is not NULL:
+ c_next = _nextElement(c_child)
+ if matcher.matches(c_child):
+ if c_child.type == tree.XML_ELEMENT_NODE:
+ if not with_tail:
+ tree.xmlUnlinkNode(c_child)
+ _removeNode(doc, c_child)
+ else:
+ if with_tail:
+ _removeText(c_child.next)
+ tree.xmlUnlinkNode(c_child)
+ attemptDeallocation(c_child)
+ c_child = c_next
+ tree.END_FOR_EACH_ELEMENT_FROM(c_node)
+
+
+def strip_tags(tree_or_element, *tag_names):
+ """strip_tags(tree_or_element, *tag_names)
+
+ Delete all elements with the provided tag names from a tree or
+ subtree. This will remove the elements and their attributes, but
+ *not* their text/tail content or descendants. Instead, it will
+ merge the text content and children of the element into its
+ parent.
+
+ Tag names can contain wildcards as in `_Element.iter`.
+
+ Note that this will not delete the element (or ElementTree root
+ element) that you passed even if it matches. It will only treat
+ its descendants.
+
+ Example usage::
+
+ strip_tags(some_element,
+ 'simpletagname', # non-namespaced tag
+ '{http://some/ns}tagname', # namespaced tag
+ '{http://some/other/ns}*' # any tag from a namespace
+ Comment # comments (including their text!)
+ )
+ """
+ cdef _MultiTagMatcher matcher
+ doc = _documentOrRaise(tree_or_element)
+ element = _rootNodeOrRaise(tree_or_element)
+ if not tag_names:
+ return
+
+ matcher = _MultiTagMatcher.__new__(_MultiTagMatcher, tag_names)
+ matcher.cacheTags(doc)
+ if matcher.rejectsAll():
+ return
+
+ if isinstance(tree_or_element, _ElementTree):
+ # include PIs and comments next to the root node
+ if matcher.matchesType(tree.XML_COMMENT_NODE):
+ _removeSiblings(element._c_node, tree.XML_COMMENT_NODE, 0)
+ if matcher.matchesType(tree.XML_PI_NODE):
+ _removeSiblings(element._c_node, tree.XML_PI_NODE, 0)
+ _strip_tags(doc, element._c_node, matcher)
+
+cdef _strip_tags(_Document doc, xmlNode* c_node, _MultiTagMatcher matcher):
+ cdef xmlNode* c_child
+ cdef xmlNode* c_next
+
+ tree.BEGIN_FOR_EACH_ELEMENT_FROM(c_node, c_node, 1)
+ if c_node.type == tree.XML_ELEMENT_NODE:
+ # we run through the children here to prevent any problems
+ # with the tree iteration which would occur if we unlinked the
+ # c_node itself
+ c_child = _findChildForwards(c_node, 0)
+ while c_child is not NULL:
+ if not matcher.matches(c_child):
+ c_child = _nextElement(c_child)
+ continue
+ if c_child.type == tree.XML_ELEMENT_NODE:
+ c_next = _findChildForwards(c_child, 0) or _nextElement(c_child)
+ _replaceNodeByChildren(doc, c_child)
+ if not attemptDeallocation(c_child):
+ if c_child.nsDef is not NULL:
+ # make namespaces absolute
+ moveNodeToDocument(doc, doc._c_doc, c_child)
+ c_child = c_next
+ else:
+ c_next = _nextElement(c_child)
+ tree.xmlUnlinkNode(c_child)
+ attemptDeallocation(c_child)
+ c_child = c_next
+ tree.END_FOR_EACH_ELEMENT_FROM(c_node)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/cssselect.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/cssselect.py
new file mode 100644
index 0000000000000000000000000000000000000000..54cd75ac9bfecdec7ea81e91b0840c6edd401515
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/cssselect.py
@@ -0,0 +1,101 @@
+"""CSS Selectors based on XPath.
+
+This module supports selecting XML/HTML tags based on CSS selectors.
+See the `CSSSelector` class for details.
+
+This is a thin wrapper around cssselect 0.7 or later.
+"""
+
+
+from . import etree
+try:
+ import cssselect as external_cssselect
+except ImportError:
+ raise ImportError(
+ 'cssselect does not seem to be installed. '
+ 'See https://pypi.org/project/cssselect/')
+
+
+SelectorSyntaxError = external_cssselect.SelectorSyntaxError
+ExpressionError = external_cssselect.ExpressionError
+SelectorError = external_cssselect.SelectorError
+
+
+__all__ = ['SelectorSyntaxError', 'ExpressionError', 'SelectorError',
+ 'CSSSelector']
+
+
+class LxmlTranslator(external_cssselect.GenericTranslator):
+ """
+ A custom CSS selector to XPath translator with lxml-specific extensions.
+ """
+ def xpath_contains_function(self, xpath, function):
+ # Defined there, removed in later drafts:
+ # http://www.w3.org/TR/2001/CR-css3-selectors-20011113/#content-selectors
+ if function.argument_types() not in (['STRING'], ['IDENT']):
+ raise ExpressionError(
+ "Expected a single string or ident for :contains(), got %r"
+ % function.arguments)
+ value = function.arguments[0].value
+ return xpath.add_condition(
+ 'contains(__lxml_internal_css:lower-case(string(.)), %s)'
+ % self.xpath_literal(value.lower()))
+
+
+class LxmlHTMLTranslator(LxmlTranslator, external_cssselect.HTMLTranslator):
+ """
+ lxml extensions + HTML support.
+ """
+
+
+def _make_lower_case(context, s):
+ return s.lower()
+
+ns = etree.FunctionNamespace('http://codespeak.net/lxml/css/')
+ns.prefix = '__lxml_internal_css'
+ns['lower-case'] = _make_lower_case
+
+
+class CSSSelector(etree.XPath):
+ """A CSS selector.
+
+ Usage::
+
+ >>> from lxml import etree, cssselect
+ >>> select = cssselect.CSSSelector("a tag > child")
+
+ >>> root = etree.XML("TEXT ")
+ >>> [ el.tag for el in select(root) ]
+ ['child']
+
+ To use CSS namespaces, you need to pass a prefix-to-namespace
+ mapping as ``namespaces`` keyword argument::
+
+ >>> rdfns = 'http://www.w3.org/1999/02/22-rdf-syntax-ns#'
+ >>> select_ns = cssselect.CSSSelector('root > rdf|Description',
+ ... namespaces={'rdf': rdfns})
+
+ >>> rdf = etree.XML((
+ ... ''
+ ... 'blah '
+ ... ' ') % rdfns)
+ >>> [(el.tag, el.text) for el in select_ns(rdf)]
+ [('{http://www.w3.org/1999/02/22-rdf-syntax-ns#}Description', 'blah')]
+
+ """
+ def __init__(self, css, namespaces=None, translator='xml'):
+ if translator == 'xml':
+ translator = LxmlTranslator()
+ elif translator == 'html':
+ translator = LxmlHTMLTranslator()
+ elif translator == 'xhtml':
+ translator = LxmlHTMLTranslator(xhtml=True)
+ path = translator.css_to_xpath(css)
+ super().__init__(path, namespaces=namespaces)
+ self.css = css
+
+ def __repr__(self):
+ return '<%s %x for %r>' % (
+ self.__class__.__name__,
+ abs(id(self)),
+ self.css)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/doctestcompare.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/doctestcompare.py
new file mode 100644
index 0000000000000000000000000000000000000000..8099771de906a37ed007c779f152fe96f182060d
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/doctestcompare.py
@@ -0,0 +1,488 @@
+"""
+lxml-based doctest output comparison.
+
+Note: normally, you should just import the `lxml.usedoctest` and
+`lxml.html.usedoctest` modules from within a doctest, instead of this
+one::
+
+ >>> import lxml.usedoctest # for XML output
+
+ >>> import lxml.html.usedoctest # for HTML output
+
+To use this module directly, you must call ``lxmldoctest.install()``,
+which will cause doctest to use this in all subsequent calls.
+
+This changes the way output is checked and comparisons are made for
+XML or HTML-like content.
+
+XML or HTML content is noticed because the example starts with ``<``
+(it's HTML if it starts with ```` or include an ``any``
+attribute in the tag. An ``any`` tag matches any tag, while the
+attribute matches any and all attributes.
+
+When a match fails, the reformatted example and gotten text is
+displayed (indented), and a rough diff-like output is given. Anything
+marked with ``+`` is in the output but wasn't supposed to be, and
+similarly ``-`` means its in the example but wasn't in the output.
+
+You can disable parsing on one line with ``# doctest:+NOPARSE_MARKUP``
+"""
+
+from lxml import etree
+import sys
+import re
+import doctest
+try:
+ from html import escape as html_escape
+except ImportError:
+ from cgi import escape as html_escape
+
+__all__ = ['PARSE_HTML', 'PARSE_XML', 'NOPARSE_MARKUP', 'LXMLOutputChecker',
+ 'LHTMLOutputChecker', 'install', 'temp_install']
+
+PARSE_HTML = doctest.register_optionflag('PARSE_HTML')
+PARSE_XML = doctest.register_optionflag('PARSE_XML')
+NOPARSE_MARKUP = doctest.register_optionflag('NOPARSE_MARKUP')
+
+OutputChecker = doctest.OutputChecker
+
+def strip(v):
+ if v is None:
+ return None
+ else:
+ return v.strip()
+
+def norm_whitespace(v):
+ return _norm_whitespace_re.sub(' ', v)
+
+_html_parser = etree.HTMLParser(recover=False, remove_blank_text=True)
+
+def html_fromstring(html):
+ return etree.fromstring(html, _html_parser)
+
+# We use this to distinguish repr()s from elements:
+_repr_re = re.compile(r'^<[^>]+ (at|object) ')
+_norm_whitespace_re = re.compile(r'[ \t\n][ \t\n]+')
+
+class LXMLOutputChecker(OutputChecker):
+
+ empty_tags = (
+ 'param', 'img', 'area', 'br', 'basefont', 'input',
+ 'base', 'meta', 'link', 'col')
+
+ def get_default_parser(self):
+ return etree.XML
+
+ def check_output(self, want, got, optionflags):
+ alt_self = getattr(self, '_temp_override_self', None)
+ if alt_self is not None:
+ super_method = self._temp_call_super_check_output
+ self = alt_self
+ else:
+ super_method = OutputChecker.check_output
+ parser = self.get_parser(want, got, optionflags)
+ if not parser:
+ return super_method(
+ self, want, got, optionflags)
+ try:
+ want_doc = parser(want)
+ except etree.XMLSyntaxError:
+ return False
+ try:
+ got_doc = parser(got)
+ except etree.XMLSyntaxError:
+ return False
+ return self.compare_docs(want_doc, got_doc)
+
+ def get_parser(self, want, got, optionflags):
+ parser = None
+ if NOPARSE_MARKUP & optionflags:
+ return None
+ if PARSE_HTML & optionflags:
+ parser = html_fromstring
+ elif PARSE_XML & optionflags:
+ parser = etree.XML
+ elif (want.strip().lower().startswith('' % el.tag
+ return '<%s %s>' % (el.tag, ' '.join(attrs))
+
+ def format_end_tag(self, el):
+ if isinstance(el, etree.CommentBase):
+ # FIXME: probably PIs should be handled specially too?
+ return '-->'
+ return '%s>' % el.tag
+
+ def collect_diff(self, want, got, html, indent):
+ parts = []
+ if not len(want) and not len(got):
+ parts.append(' '*indent)
+ parts.append(self.collect_diff_tag(want, got))
+ if not self.html_empty_tag(got, html):
+ parts.append(self.collect_diff_text(want.text, got.text))
+ parts.append(self.collect_diff_end_tag(want, got))
+ parts.append(self.collect_diff_text(want.tail, got.tail))
+ parts.append('\n')
+ return ''.join(parts)
+ parts.append(' '*indent)
+ parts.append(self.collect_diff_tag(want, got))
+ parts.append('\n')
+ if strip(want.text) or strip(got.text):
+ parts.append(' '*indent)
+ parts.append(self.collect_diff_text(want.text, got.text))
+ parts.append('\n')
+ want_children = list(want)
+ got_children = list(got)
+ while want_children or got_children:
+ if not want_children:
+ parts.append(self.format_doc(got_children.pop(0), html, indent+2, '+'))
+ continue
+ if not got_children:
+ parts.append(self.format_doc(want_children.pop(0), html, indent+2, '-'))
+ continue
+ parts.append(self.collect_diff(
+ want_children.pop(0), got_children.pop(0), html, indent+2))
+ parts.append(' '*indent)
+ parts.append(self.collect_diff_end_tag(want, got))
+ parts.append('\n')
+ if strip(want.tail) or strip(got.tail):
+ parts.append(' '*indent)
+ parts.append(self.collect_diff_text(want.tail, got.tail))
+ parts.append('\n')
+ return ''.join(parts)
+
+ def collect_diff_tag(self, want, got):
+ if not self.tag_compare(want.tag, got.tag):
+ tag = '%s (got: %s)' % (want.tag, got.tag)
+ else:
+ tag = got.tag
+ attrs = []
+ any = want.tag == 'any' or 'any' in want.attrib
+ for name, value in sorted(got.attrib.items()):
+ if name not in want.attrib and not any:
+ attrs.append('+%s="%s"' % (name, self.format_text(value, False)))
+ else:
+ if name in want.attrib:
+ text = self.collect_diff_text(want.attrib[name], value, False)
+ else:
+ text = self.format_text(value, False)
+ attrs.append('%s="%s"' % (name, text))
+ if not any:
+ for name, value in sorted(want.attrib.items()):
+ if name in got.attrib:
+ continue
+ attrs.append('-%s="%s"' % (name, self.format_text(value, False)))
+ if attrs:
+ tag = '<%s %s>' % (tag, ' '.join(attrs))
+ else:
+ tag = '<%s>' % tag
+ return tag
+
+ def collect_diff_end_tag(self, want, got):
+ if want.tag != got.tag:
+ tag = '%s (got: %s)' % (want.tag, got.tag)
+ else:
+ tag = got.tag
+ return '%s>' % tag
+
+ def collect_diff_text(self, want, got, strip=True):
+ if self.text_compare(want, got, strip):
+ if not got:
+ return ''
+ return self.format_text(got, strip)
+ text = '%s (got: %s)' % (want, got)
+ return self.format_text(text, strip)
+
+class LHTMLOutputChecker(LXMLOutputChecker):
+ def get_default_parser(self):
+ return html_fromstring
+
+def install(html=False):
+ """
+ Install doctestcompare for all future doctests.
+
+ If html is true, then by default the HTML parser will be used;
+ otherwise the XML parser is used.
+ """
+ if html:
+ doctest.OutputChecker = LHTMLOutputChecker
+ else:
+ doctest.OutputChecker = LXMLOutputChecker
+
+def temp_install(html=False, del_module=None):
+ """
+ Use this *inside* a doctest to enable this checker for this
+ doctest only.
+
+ If html is true, then by default the HTML parser will be used;
+ otherwise the XML parser is used.
+ """
+ if html:
+ Checker = LHTMLOutputChecker
+ else:
+ Checker = LXMLOutputChecker
+ frame = _find_doctest_frame()
+ dt_self = frame.f_locals['self']
+ checker = Checker()
+ old_checker = dt_self._checker
+ dt_self._checker = checker
+ # The unfortunate thing is that there is a local variable 'check'
+ # in the function that runs the doctests, that is a bound method
+ # into the output checker. We have to update that. We can't
+ # modify the frame, so we have to modify the object in place. The
+ # only way to do this is to actually change the func_code
+ # attribute of the method. We change it, and then wait for
+ # __record_outcome to be run, which signals the end of the __run
+ # method, at which point we restore the previous check_output
+ # implementation.
+ check_func = frame.f_locals['check'].__func__
+ checker_check_func = checker.check_output.__func__
+ # Because we can't patch up func_globals, this is the only global
+ # in check_output that we care about:
+ doctest.etree = etree
+ _RestoreChecker(dt_self, old_checker, checker,
+ check_func, checker_check_func,
+ del_module)
+
+class _RestoreChecker:
+ def __init__(self, dt_self, old_checker, new_checker, check_func, clone_func,
+ del_module):
+ self.dt_self = dt_self
+ self.checker = old_checker
+ self.checker._temp_call_super_check_output = self.call_super
+ self.checker._temp_override_self = new_checker
+ self.check_func = check_func
+ self.clone_func = clone_func
+ self.del_module = del_module
+ self.install_clone()
+ self.install_dt_self()
+ def install_clone(self):
+ self.func_code = self.check_func.__code__
+ self.func_globals = self.check_func.__globals__
+ self.check_func.__code__ = self.clone_func.__code__
+ def uninstall_clone(self):
+ self.check_func.__code__ = self.func_code
+ def install_dt_self(self):
+ self.prev_func = self.dt_self._DocTestRunner__record_outcome
+ self.dt_self._DocTestRunner__record_outcome = self
+ def uninstall_dt_self(self):
+ self.dt_self._DocTestRunner__record_outcome = self.prev_func
+ def uninstall_module(self):
+ if self.del_module:
+ import sys
+ del sys.modules[self.del_module]
+ if '.' in self.del_module:
+ package, module = self.del_module.rsplit('.', 1)
+ package_mod = sys.modules[package]
+ delattr(package_mod, module)
+ def __call__(self, *args, **kw):
+ self.uninstall_clone()
+ self.uninstall_dt_self()
+ del self.checker._temp_override_self
+ del self.checker._temp_call_super_check_output
+ result = self.prev_func(*args, **kw)
+ self.uninstall_module()
+ return result
+ def call_super(self, *args, **kw):
+ self.uninstall_clone()
+ try:
+ return self.check_func(*args, **kw)
+ finally:
+ self.install_clone()
+
+def _find_doctest_frame():
+ import sys
+ frame = sys._getframe(1)
+ while frame:
+ l = frame.f_locals
+ if 'BOOM' in l:
+ # Sign of doctest
+ return frame
+ frame = frame.f_back
+ raise LookupError(
+ "Could not find doctest (only use this function *inside* a doctest)")
+
+__test__ = {
+ 'basic': '''
+ >>> temp_install()
+ >>> print """stuff """
+ ...
+ >>> print """ """
+
+
+
+ >>> print """blahblahblah """ # doctest: +NOPARSE_MARKUP, +ELLIPSIS
+ ...foo />
+ '''}
+
+if __name__ == '__main__':
+ import doctest
+ doctest.testmod()
+
+
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/etree.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/etree.h
new file mode 100644
index 0000000000000000000000000000000000000000..17b99a7be5c4159429d575c9e98f621f57c8310c
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/etree.h
@@ -0,0 +1,244 @@
+/* Generated by Cython 3.1.4 */
+
+#ifndef __PYX_HAVE__lxml__etree
+#define __PYX_HAVE__lxml__etree
+
+#include "Python.h"
+struct LxmlDocument;
+struct LxmlElement;
+struct LxmlElementTree;
+struct LxmlElementTagMatcher;
+struct LxmlElementIterator;
+struct LxmlElementBase;
+struct LxmlElementClassLookup;
+struct LxmlFallbackElementClassLookup;
+
+/* "lxml/etree.pyx":451
+ *
+ * # type of a function that steps from node to node
+ * ctypedef public xmlNode* (*_node_to_node_function)(xmlNode*) # <<<<<<<<<<<<<<
+ *
+ *
+*/
+typedef xmlNode *(*_node_to_node_function)(xmlNode *);
+
+/* "lxml/etree.pyx":465
+ * # Public Python API
+ *
+ * @cython.final # <<<<<<<<<<<<<<
+ * @cython.freelist(8)
+ * cdef public class _Document [ type LxmlDocumentType, object LxmlDocument ]:
+*/
+struct LxmlDocument {
+ PyObject_HEAD
+ struct __pyx_vtabstruct_4lxml_5etree__Document *__pyx_vtab;
+ int _ns_counter;
+ PyObject *_prefix_tail;
+ xmlDoc *_c_doc;
+ struct __pyx_obj_4lxml_5etree__BaseParser *_parser;
+};
+
+/* "lxml/etree.pyx":817
+ *
+ *
+ * @cython.no_gc_clear # <<<<<<<<<<<<<<
+ * cdef public class _Element [ type LxmlElementType, object LxmlElement ]:
+ * """Element class.
+*/
+struct LxmlElement {
+ PyObject_HEAD
+ struct LxmlDocument *_doc;
+ xmlNode *_c_node;
+ PyObject *_tag;
+};
+
+/* "lxml/etree.pyx":1991
+ *
+ *
+ * cdef public class _ElementTree [ type LxmlElementTreeType, # <<<<<<<<<<<<<<
+ * object LxmlElementTree ]:
+ * cdef _Document _doc
+*/
+struct LxmlElementTree {
+ PyObject_HEAD
+ struct __pyx_vtabstruct_4lxml_5etree__ElementTree *__pyx_vtab;
+ struct LxmlDocument *_doc;
+ struct LxmlElement *_context_node;
+};
+
+/* "lxml/etree.pyx":2765
+ *
+ *
+ * cdef public class _ElementTagMatcher [ object LxmlElementTagMatcher, # <<<<<<<<<<<<<<
+ * type LxmlElementTagMatcherType ]:
+ * """
+*/
+struct LxmlElementTagMatcher {
+ PyObject_HEAD
+ struct __pyx_vtabstruct_4lxml_5etree__ElementTagMatcher *__pyx_vtab;
+ PyObject *_pystrings;
+ int _node_type;
+ char *_href;
+ char *_name;
+};
+
+/* "lxml/etree.pyx":2796
+ * self._name = NULL
+ *
+ * cdef public class _ElementIterator(_ElementTagMatcher) [ # <<<<<<<<<<<<<<
+ * object LxmlElementIterator, type LxmlElementIteratorType ]:
+ * """
+*/
+struct LxmlElementIterator {
+ struct LxmlElementTagMatcher __pyx_base;
+ struct LxmlElement *_node;
+ _node_to_node_function _next_element;
+};
+
+/* "src/lxml/classlookup.pxi":6
+ * # Custom Element classes
+ *
+ * cdef public class ElementBase(_Element) [ type LxmlElementBaseType, # <<<<<<<<<<<<<<
+ * object LxmlElementBase ]:
+ * """ElementBase(*children, attrib=None, nsmap=None, **_extra)
+*/
+struct LxmlElementBase {
+ struct LxmlElement __pyx_base;
+};
+
+/* "src/lxml/classlookup.pxi":210
+ * # Element class lookup
+ *
+ * ctypedef public object (*_element_class_lookup_function)(object, _Document, xmlNode*) # <<<<<<<<<<<<<<
+ *
+ * # class to store element class lookup functions
+*/
+typedef PyObject *(*_element_class_lookup_function)(PyObject *, struct LxmlDocument *, xmlNode *);
+
+/* "src/lxml/classlookup.pxi":213
+ *
+ * # class to store element class lookup functions
+ * cdef public class ElementClassLookup [ type LxmlElementClassLookupType, # <<<<<<<<<<<<<<
+ * object LxmlElementClassLookup ]:
+ * """ElementClassLookup(self)
+*/
+struct LxmlElementClassLookup {
+ PyObject_HEAD
+ _element_class_lookup_function _lookup_function;
+};
+
+/* "src/lxml/classlookup.pxi":221
+ *
+ *
+ * cdef public class FallbackElementClassLookup(ElementClassLookup) \ # <<<<<<<<<<<<<<
+ * [ type LxmlFallbackElementClassLookupType,
+ * object LxmlFallbackElementClassLookup ]:
+*/
+struct LxmlFallbackElementClassLookup {
+ struct LxmlElementClassLookup __pyx_base;
+ struct __pyx_vtabstruct_4lxml_5etree_FallbackElementClassLookup *__pyx_vtab;
+ struct LxmlElementClassLookup *fallback;
+ _element_class_lookup_function _fallback_function;
+};
+
+#ifndef __PYX_HAVE_API__lxml__etree
+
+#ifdef CYTHON_EXTERN_C
+ #undef __PYX_EXTERN_C
+ #define __PYX_EXTERN_C CYTHON_EXTERN_C
+#elif defined(__PYX_EXTERN_C)
+ #ifdef _MSC_VER
+ #pragma message ("Please do not define the '__PYX_EXTERN_C' macro externally. Use 'CYTHON_EXTERN_C' instead.")
+ #else
+ #warning Please do not define the '__PYX_EXTERN_C' macro externally. Use 'CYTHON_EXTERN_C' instead.
+ #endif
+#else
+ #ifdef __cplusplus
+ #define __PYX_EXTERN_C extern "C"
+ #else
+ #define __PYX_EXTERN_C extern
+ #endif
+#endif
+
+#ifndef DL_IMPORT
+ #define DL_IMPORT(_T) _T
+#endif
+
+__PYX_EXTERN_C DL_IMPORT(PyTypeObject) LxmlDocumentType;
+__PYX_EXTERN_C DL_IMPORT(PyTypeObject) LxmlElementType;
+__PYX_EXTERN_C DL_IMPORT(PyTypeObject) LxmlElementTreeType;
+__PYX_EXTERN_C DL_IMPORT(PyTypeObject) LxmlElementTagMatcherType;
+__PYX_EXTERN_C DL_IMPORT(PyTypeObject) LxmlElementIteratorType;
+__PYX_EXTERN_C DL_IMPORT(PyTypeObject) LxmlElementBaseType;
+__PYX_EXTERN_C DL_IMPORT(PyTypeObject) LxmlElementClassLookupType;
+__PYX_EXTERN_C DL_IMPORT(PyTypeObject) LxmlFallbackElementClassLookupType;
+
+__PYX_EXTERN_C struct LxmlElement *deepcopyNodeToDocument(struct LxmlDocument *, xmlNode *);
+__PYX_EXTERN_C struct LxmlElementTree *elementTreeFactory(struct LxmlElement *);
+__PYX_EXTERN_C struct LxmlElementTree *newElementTree(struct LxmlElement *, PyObject *);
+__PYX_EXTERN_C struct LxmlElementTree *adoptExternalDocument(xmlDoc *, PyObject *, int);
+__PYX_EXTERN_C struct LxmlElement *elementFactory(struct LxmlDocument *, xmlNode *);
+__PYX_EXTERN_C struct LxmlElement *makeElement(PyObject *, struct LxmlDocument *, PyObject *, PyObject *, PyObject *, PyObject *, PyObject *);
+__PYX_EXTERN_C struct LxmlElement *makeSubElement(struct LxmlElement *, PyObject *, PyObject *, PyObject *, PyObject *, PyObject *);
+__PYX_EXTERN_C void setElementClassLookupFunction(_element_class_lookup_function, PyObject *);
+__PYX_EXTERN_C PyObject *lookupDefaultElementClass(PyObject *, PyObject *, xmlNode *);
+__PYX_EXTERN_C PyObject *lookupNamespaceElementClass(PyObject *, PyObject *, xmlNode *);
+__PYX_EXTERN_C PyObject *callLookupFallback(struct LxmlFallbackElementClassLookup *, struct LxmlDocument *, xmlNode *);
+__PYX_EXTERN_C int tagMatches(xmlNode *, const xmlChar *, const xmlChar *);
+__PYX_EXTERN_C struct LxmlDocument *documentOrRaise(PyObject *);
+__PYX_EXTERN_C struct LxmlElement *rootNodeOrRaise(PyObject *);
+__PYX_EXTERN_C int hasText(xmlNode *);
+__PYX_EXTERN_C int hasTail(xmlNode *);
+__PYX_EXTERN_C PyObject *textOf(xmlNode *);
+__PYX_EXTERN_C PyObject *tailOf(xmlNode *);
+__PYX_EXTERN_C int setNodeText(xmlNode *, PyObject *);
+__PYX_EXTERN_C int setTailText(xmlNode *, PyObject *);
+__PYX_EXTERN_C PyObject *attributeValue(xmlNode *, xmlAttr *);
+__PYX_EXTERN_C PyObject *attributeValueFromNsName(xmlNode *, const xmlChar *, const xmlChar *);
+__PYX_EXTERN_C PyObject *getAttributeValue(struct LxmlElement *, PyObject *, PyObject *);
+__PYX_EXTERN_C PyObject *iterattributes(struct LxmlElement *, int);
+__PYX_EXTERN_C PyObject *collectAttributes(xmlNode *, int);
+__PYX_EXTERN_C int setAttributeValue(struct LxmlElement *, PyObject *, PyObject *);
+__PYX_EXTERN_C int delAttribute(struct LxmlElement *, PyObject *);
+__PYX_EXTERN_C int delAttributeFromNsName(xmlNode *, const xmlChar *, const xmlChar *);
+__PYX_EXTERN_C int hasChild(xmlNode *);
+__PYX_EXTERN_C xmlNode *findChild(xmlNode *, Py_ssize_t);
+__PYX_EXTERN_C xmlNode *findChildForwards(xmlNode *, Py_ssize_t);
+__PYX_EXTERN_C xmlNode *findChildBackwards(xmlNode *, Py_ssize_t);
+__PYX_EXTERN_C xmlNode *nextElement(xmlNode *);
+__PYX_EXTERN_C xmlNode *previousElement(xmlNode *);
+__PYX_EXTERN_C void appendChild(struct LxmlElement *, struct LxmlElement *);
+__PYX_EXTERN_C int appendChildToElement(struct LxmlElement *, struct LxmlElement *);
+__PYX_EXTERN_C PyObject *pyunicode(const xmlChar *);
+__PYX_EXTERN_C PyObject *utf8(PyObject *);
+__PYX_EXTERN_C PyObject *getNsTag(PyObject *);
+__PYX_EXTERN_C PyObject *getNsTagWithEmptyNs(PyObject *);
+__PYX_EXTERN_C PyObject *namespacedName(xmlNode *);
+__PYX_EXTERN_C PyObject *namespacedNameFromNsName(const xmlChar *, const xmlChar *);
+__PYX_EXTERN_C void iteratorStoreNext(struct LxmlElementIterator *, struct LxmlElement *);
+__PYX_EXTERN_C void initTagMatch(struct LxmlElementTagMatcher *, PyObject *);
+__PYX_EXTERN_C xmlNs *findOrBuildNodeNsPrefix(struct LxmlDocument *, xmlNode *, const xmlChar *, const xmlChar *);
+
+#endif /* !__PYX_HAVE_API__lxml__etree */
+
+/* WARNING: the interface of the module init function changed in CPython 3.5. */
+/* It now returns a PyModuleDef instance instead of a PyModule instance. */
+
+/* WARNING: Use PyImport_AppendInittab("etree", PyInit_etree) instead of calling PyInit_etree directly from Python 3.5 */
+PyMODINIT_FUNC PyInit_etree(void);
+
+#if PY_VERSION_HEX >= 0x03050000 && (defined(__GNUC__) || defined(__clang__) || defined(_MSC_VER) || (defined(__cplusplus) && __cplusplus >= 201402L))
+#if defined(__cplusplus) && __cplusplus >= 201402L
+[[deprecated("Use PyImport_AppendInittab(\"etree\", PyInit_etree) instead of calling PyInit_etree directly.")]] inline
+#elif defined(__GNUC__) || defined(__clang__)
+__attribute__ ((__deprecated__("Use PyImport_AppendInittab(\"etree\", PyInit_etree) instead of calling PyInit_etree directly."), __unused__)) __inline__
+#elif defined(_MSC_VER)
+__declspec(deprecated("Use PyImport_AppendInittab(\"etree\", PyInit_etree) instead of calling PyInit_etree directly.")) __inline
+#endif
+static PyObject* __PYX_WARN_IF_PyInit_etree_INIT_CALLED(PyObject* res) {
+ return res;
+}
+#define PyInit_etree() __PYX_WARN_IF_PyInit_etree_INIT_CALLED(PyInit_etree())
+#endif
+
+#endif /* !__PYX_HAVE__lxml__etree */
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/etree.pyx b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/etree.pyx
new file mode 100644
index 0000000000000000000000000000000000000000..562d95ed167945504fd50182824340a5931c4b10
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/etree.pyx
@@ -0,0 +1,3853 @@
+# cython: binding=True
+# cython: auto_pickle=False
+# cython: language_level=3
+
+"""
+The ``lxml.etree`` module implements the extended ElementTree API for XML.
+"""
+
+__docformat__ = "restructuredtext en"
+
+__all__ = [
+ 'AttributeBasedElementClassLookup', 'C14NError', 'C14NWriterTarget', 'CDATA',
+ 'Comment', 'CommentBase', 'CustomElementClassLookup', 'DEBUG',
+ 'DTD', 'DTDError', 'DTDParseError', 'DTDValidateError',
+ 'DocumentInvalid', 'ETCompatXMLParser', 'ETXPath', 'Element',
+ 'ElementBase', 'ElementClassLookup', 'ElementDefaultClassLookup',
+ 'ElementNamespaceClassLookup', 'ElementTree', 'Entity', 'EntityBase',
+ 'Error', 'ErrorDomains', 'ErrorLevels', 'ErrorTypes', 'Extension',
+ 'FallbackElementClassLookup', 'FunctionNamespace', 'HTML', 'HTMLParser',
+ 'ICONV_COMPILED_VERSION',
+ 'LIBXML_COMPILED_VERSION', 'LIBXML_VERSION',
+ 'LIBXML_FEATURES',
+ 'LIBXSLT_COMPILED_VERSION', 'LIBXSLT_VERSION',
+ 'LXML_VERSION',
+ 'LxmlError', 'LxmlRegistryError', 'LxmlSyntaxError',
+ 'NamespaceRegistryError', 'PI', 'PIBase', 'ParseError',
+ 'ParserBasedElementClassLookup', 'ParserError', 'ProcessingInstruction',
+ 'PyErrorLog', 'PythonElementClassLookup', 'QName', 'RelaxNG',
+ 'RelaxNGError', 'RelaxNGErrorTypes', 'RelaxNGParseError',
+ 'RelaxNGValidateError', 'Resolver', 'Schematron', 'SchematronError',
+ 'SchematronParseError', 'SchematronValidateError', 'SerialisationError',
+ 'SubElement', 'TreeBuilder', 'XInclude', 'XIncludeError', 'XML',
+ 'XMLDTDID', 'XMLID', 'XMLParser', 'XMLSchema', 'XMLSchemaError',
+ 'XMLSchemaParseError', 'XMLSchemaValidateError', 'XMLSyntaxError',
+ 'XMLTreeBuilder', 'XPath', 'XPathDocumentEvaluator', 'XPathError',
+ 'XPathEvalError', 'XPathEvaluator', 'XPathFunctionError', 'XPathResultError',
+ 'XPathSyntaxError', 'XSLT', 'XSLTAccessControl', 'XSLTApplyError',
+ 'XSLTError', 'XSLTExtension', 'XSLTExtensionError', 'XSLTParseError',
+ 'XSLTSaveError', 'canonicalize',
+ 'cleanup_namespaces', 'clear_error_log', 'dump',
+ 'fromstring', 'fromstringlist', 'get_default_parser', 'iselement',
+ 'iterparse', 'iterwalk', 'parse', 'parseid', 'register_namespace',
+ 'set_default_parser', 'set_element_class_lookup', 'strip_attributes',
+ 'strip_elements', 'strip_tags', 'tostring', 'tostringlist', 'tounicode',
+ 'use_global_python_log'
+ ]
+
+cimport cython
+
+from lxml cimport python
+from lxml.includes cimport tree, config
+from lxml.includes.tree cimport xmlDoc, xmlNode, xmlAttr, xmlNs, _isElement, _getNs
+from lxml.includes.tree cimport const_xmlChar, xmlChar, _xcstr
+from lxml.python cimport _cstr, _isString
+from lxml.includes cimport xpath
+from lxml.includes cimport c14n
+
+# Cython's standard declarations
+cimport cpython.mem
+cimport cpython.ref
+from libc cimport limits, stdio, stdlib
+from libc cimport string as cstring_h # not to be confused with stdlib 'string'
+from libc.string cimport const_char
+
+cdef object os_path_abspath
+from os.path import abspath as os_path_abspath
+
+cdef object BytesIO, StringIO
+from io import BytesIO, StringIO
+
+cdef object OrderedDict
+from collections import OrderedDict
+
+cdef object _elementpath
+from lxml import _elementpath
+
+cdef object sys
+import sys
+
+cdef object re
+import re
+
+cdef object partial
+from functools import partial
+
+cdef object islice
+from itertools import islice
+
+cdef object ITER_EMPTY = iter(())
+
+cdef object MutableMapping
+from collections.abc import MutableMapping
+
+class _ImmutableMapping(MutableMapping):
+ def __getitem__(self, key):
+ raise KeyError, key
+
+ def __setitem__(self, key, value):
+ raise KeyError, key
+
+ def __delitem__(self, key):
+ raise KeyError, key
+
+ def __contains__(self, key):
+ return False
+
+ def __len__(self):
+ return 0
+
+ def __iter__(self):
+ return ITER_EMPTY
+ iterkeys = itervalues = iteritems = __iter__
+
+cdef object IMMUTABLE_EMPTY_MAPPING = _ImmutableMapping()
+del _ImmutableMapping
+
+
+# the rules
+# ---------
+# any libxml C argument/variable is prefixed with c_
+# any non-public function/class is prefixed with an underscore
+# instance creation is always through factories
+
+# what to do with libxml2/libxslt error messages?
+# 0 : drop
+# 1 : use log
+DEF __DEBUG = 1
+
+# maximum number of lines in the libxml2/xslt log if __DEBUG == 1
+DEF __MAX_LOG_SIZE = 100
+
+# make the compiled-in debug state publicly available
+DEBUG = __DEBUG
+
+# A struct to store a cached qualified tag name+href pair.
+# While we can borrow the c_name from the document dict,
+# PyPy requires us to store a Python reference for the
+# namespace in order to keep the byte buffer alive.
+cdef struct qname:
+ const_xmlChar* c_name
+ python.PyObject* href
+
+# initialize parser (and threading)
+xmlparser.xmlInitParser()
+
+# global per-thread setup
+tree.xmlThrDefIndentTreeOutput(1)
+tree.xmlThrDefLineNumbersDefaultValue(1)
+
+_initThreadLogging()
+
+# filename encoding
+cdef bytes _FILENAME_ENCODING = (sys.getfilesystemencoding() or sys.getdefaultencoding() or 'ascii').encode("UTF-8")
+cdef char* _C_FILENAME_ENCODING = _cstr(_FILENAME_ENCODING)
+
+# set up some default namespace prefixes
+cdef dict _DEFAULT_NAMESPACE_PREFIXES = {
+ b"http://www.w3.org/XML/1998/namespace": b'xml',
+ b"http://www.w3.org/1999/xhtml": b"html",
+ b"http://www.w3.org/1999/XSL/Transform": b"xsl",
+ b"http://www.w3.org/1999/02/22-rdf-syntax-ns#": b"rdf",
+ b"http://schemas.xmlsoap.org/wsdl/": b"wsdl",
+ # xml schema
+ b"http://www.w3.org/2001/XMLSchema": b"xs",
+ b"http://www.w3.org/2001/XMLSchema-instance": b"xsi",
+ # dublin core
+ b"http://purl.org/dc/elements/1.1/": b"dc",
+ # objectify
+ b"http://codespeak.net/lxml/objectify/pytype" : b"py",
+}
+
+# To avoid runtime encoding overhead, we keep a Unicode copy
+# of the uri-prefix mapping as (str, str) items view.
+cdef object _DEFAULT_NAMESPACE_PREFIXES_ITEMS = []
+
+cdef _update_default_namespace_prefixes_items():
+ cdef bytes ns, prefix
+ global _DEFAULT_NAMESPACE_PREFIXES_ITEMS
+ _DEFAULT_NAMESPACE_PREFIXES_ITEMS = {
+ ns.decode('utf-8') : prefix.decode('utf-8')
+ for ns, prefix in _DEFAULT_NAMESPACE_PREFIXES.items()
+ }.items()
+
+_update_default_namespace_prefixes_items()
+
+cdef object _check_internal_prefix = re.compile(br"ns\d+$").match
+
+def register_namespace(prefix, uri):
+ """Registers a namespace prefix that newly created Elements in that
+ namespace will use. The registry is global, and any existing
+ mapping for either the given prefix or the namespace URI will be
+ removed.
+ """
+ prefix_utf, uri_utf = _utf8(prefix), _utf8(uri)
+ if _check_internal_prefix(prefix_utf):
+ raise ValueError("Prefix format reserved for internal use")
+ _tagValidOrRaise(prefix_utf)
+ _uriValidOrRaise(uri_utf)
+ if (uri_utf == b"http://www.w3.org/XML/1998/namespace" and prefix_utf != b'xml'
+ or prefix_utf == b'xml' and uri_utf != b"http://www.w3.org/XML/1998/namespace"):
+ raise ValueError("Cannot change the 'xml' prefix of the XML namespace")
+ for k, v in list(_DEFAULT_NAMESPACE_PREFIXES.items()):
+ if k == uri_utf or v == prefix_utf:
+ del _DEFAULT_NAMESPACE_PREFIXES[k]
+ _DEFAULT_NAMESPACE_PREFIXES[uri_utf] = prefix_utf
+ _update_default_namespace_prefixes_items()
+
+
+# Error superclass for ElementTree compatibility
+cdef class Error(Exception):
+ pass
+
+# module level superclass for all exceptions
+cdef class LxmlError(Error):
+ """Main exception base class for lxml. All other exceptions inherit from
+ this one.
+ """
+ def __init__(self, message, error_log=None):
+ super(_Error, self).__init__(message)
+ if error_log is None:
+ self.error_log = __copyGlobalErrorLog()
+ else:
+ self.error_log = error_log.copy()
+
+cdef object _Error = Error
+
+
+# superclass for all syntax errors
+class LxmlSyntaxError(LxmlError, SyntaxError):
+ """Base class for all syntax errors.
+ """
+
+cdef class C14NError(LxmlError):
+ """Error during C14N serialisation.
+ """
+
+# version information
+cdef tuple __unpackDottedVersion(version):
+ version_list = []
+ l = (version.decode("ascii").replace('-', '.').split('.') + [0]*4)[:4]
+ for item in l:
+ try:
+ item = int(item)
+ except ValueError:
+ if item.startswith('dev'):
+ count = item[3:]
+ item = -300
+ elif item.startswith('alpha'):
+ count = item[5:]
+ item = -200
+ elif item.startswith('beta'):
+ count = item[4:]
+ item = -100
+ else:
+ count = 0
+ if count:
+ item += int(count)
+ version_list.append(item)
+ return tuple(version_list)
+
+cdef tuple __unpackIntVersion(int c_version, int base=100):
+ return (
+ ((c_version // (base*base)) % base),
+ ((c_version // base) % base),
+ (c_version % base)
+ )
+
+cdef int _LIBXML_VERSION_INT
+try:
+ _LIBXML_VERSION_INT = int(
+ re.match('[0-9]+', (tree.xmlParserVersion).decode("ascii")).group(0))
+except Exception:
+ print("Unknown libxml2 version: " + (tree.xmlParserVersion).decode("latin1"))
+ _LIBXML_VERSION_INT = 0
+
+LIBXML_VERSION = __unpackIntVersion(_LIBXML_VERSION_INT)
+LIBXML_COMPILED_VERSION = __unpackIntVersion(tree.LIBXML_VERSION)
+LXML_VERSION = __unpackDottedVersion(tree.LXML_VERSION_STRING)
+
+__version__ = tree.LXML_VERSION_STRING.decode("ascii")
+
+cdef extern from *:
+ """
+ #ifdef ZLIB_VERNUM
+ #define __lxml_zlib_version (ZLIB_VERNUM >> 4)
+ #else
+ #define __lxml_zlib_version 0
+ #endif
+ #ifdef _LIBICONV_VERSION
+ #define __lxml_iconv_version (_LIBICONV_VERSION << 8)
+ #else
+ #define __lxml_iconv_version 0
+ #endif
+ """
+ # zlib isn't included automatically by libxml2's headers
+ #long ZLIB_HEX_VERSION "__lxml_zlib_version"
+ long LIBICONV_HEX_VERSION "__lxml_iconv_version"
+
+#ZLIB_COMPILED_VERSION = __unpackIntVersion(ZLIB_HEX_VERSION, base=0x10)
+ICONV_COMPILED_VERSION = __unpackIntVersion(LIBICONV_HEX_VERSION, base=0x100)[:2]
+
+
+cdef extern from "libxml/xmlversion.h":
+ """
+ static const char* const _lxml_lib_features[] = {
+#ifdef LIBXML_HTML_ENABLED
+ "html",
+#endif
+#ifdef LIBXML_FTP_ENABLED
+ "ftp",
+#endif
+#ifdef LIBXML_HTTP_ENABLED
+ "http",
+#endif
+#ifdef LIBXML_CATALOG_ENABLED
+ "catalog",
+#endif
+#ifdef LIBXML_XPATH_ENABLED
+ "xpath",
+#endif
+#ifdef LIBXML_ICONV_ENABLED
+ "iconv",
+#endif
+#ifdef LIBXML_ICU_ENABLED
+ "icu",
+#endif
+#ifdef LIBXML_REGEXP_ENABLED
+ "regexp",
+#endif
+#ifdef LIBXML_SCHEMAS_ENABLED
+ "xmlschema",
+#endif
+#ifdef LIBXML_SCHEMATRON_ENABLED
+ "schematron",
+#endif
+#ifdef LIBXML_ZLIB_ENABLED
+ "zlib",
+#endif
+#ifdef LIBXML_LZMA_ENABLED
+ "lzma",
+#endif
+ 0
+ };
+ """
+ const char* const* _LXML_LIB_FEATURES "_lxml_lib_features"
+
+
+cdef set _copy_lib_features():
+ features = set()
+ feature = _LXML_LIB_FEATURES
+ while feature[0]:
+ features.add(feature[0].decode('ASCII'))
+ feature += 1
+ return features
+
+LIBXML_COMPILED_FEATURES = _copy_lib_features()
+LIBXML_FEATURES = {
+ feature_name for feature_id, feature_name in [
+ #XML_WITH_THREAD = 1
+ #XML_WITH_TREE = 2
+ #XML_WITH_OUTPUT = 3
+ #XML_WITH_PUSH = 4
+ #XML_WITH_READER = 5
+ #XML_WITH_PATTERN = 6
+ #XML_WITH_WRITER = 7
+ #XML_WITH_SAX1 = 8
+ (xmlparser.XML_WITH_FTP, "ftp"), # XML_WITH_FTP = 9
+ (xmlparser.XML_WITH_HTTP, "http"), # XML_WITH_HTTP = 10
+ #XML_WITH_VALID = 11
+ (xmlparser.XML_WITH_HTML, "html"), # XML_WITH_HTML = 12
+ #XML_WITH_LEGACY = 13
+ #XML_WITH_C14N = 14
+ (xmlparser.XML_WITH_CATALOG, "catalog"), # XML_WITH_CATALOG = 15
+ (xmlparser.XML_WITH_XPATH, "xpath"), # XML_WITH_XPATH = 16
+ #XML_WITH_XPTR = 17
+ #XML_WITH_XINCLUDE = 18
+ (xmlparser.XML_WITH_ICONV, "iconv"), # XML_WITH_ICONV = 19
+ #XML_WITH_ISO8859X = 20
+ #XML_WITH_UNICODE = 21
+ (xmlparser.XML_WITH_REGEXP, "regexp"), # XML_WITH_REGEXP = 22
+ #XML_WITH_AUTOMATA = 23
+ #XML_WITH_EXPR = 24
+ (xmlparser.XML_WITH_SCHEMAS, "xmlschema"), # XML_WITH_SCHEMAS = 25
+ (xmlparser.XML_WITH_SCHEMATRON, "schematron"), # XML_WITH_SCHEMATRON = 26
+ #XML_WITH_MODULES = 27
+ #XML_WITH_DEBUG = 28
+ #XML_WITH_DEBUG_MEM = 29
+ #XML_WITH_DEBUG_RUN = 30 # unused
+ (xmlparser.XML_WITH_ZLIB, "zlib"), # XML_WITH_ZLIB = 31
+ (xmlparser.XML_WITH_ICU, "icu"), # XML_WITH_ICU = 32
+ (xmlparser.XML_WITH_LZMA, "lzma"), # XML_WITH_LZMA = 33
+ ] if xmlparser.xmlHasFeature(feature_id)
+}
+
+cdef bint HAS_ZLIB_COMPRESSION = xmlparser.xmlHasFeature(xmlparser.XML_WITH_ZLIB)
+
+
+# class for temporary storage of Python references,
+# used e.g. for XPath results
+@cython.final
+@cython.internal
+cdef class _TempStore:
+ cdef list _storage
+ def __init__(self):
+ self._storage = []
+
+ cdef int add(self, obj) except -1:
+ self._storage.append(obj)
+ return 0
+
+ cdef int clear(self) except -1:
+ del self._storage[:]
+ return 0
+
+
+# class for temporarily storing exceptions raised in extensions
+@cython.internal
+cdef class _ExceptionContext:
+ cdef object _exc_info
+ cdef int clear(self) except -1:
+ self._exc_info = None
+ return 0
+
+ cdef void _store_raised(self) noexcept:
+ try:
+ self._exc_info = sys.exc_info()
+ except BaseException as e:
+ self._store_exception(e)
+ finally:
+ return # and swallow any further exceptions
+
+ cdef int _store_exception(self, exception) except -1:
+ self._exc_info = (exception, None, None)
+ return 0
+
+ cdef bint _has_raised(self) except -1:
+ return self._exc_info is not None
+
+ cdef int _raise_if_stored(self) except -1:
+ if self._exc_info is None:
+ return 0
+ type, value, traceback = self._exc_info
+ self._exc_info = None
+ if value is None and traceback is None:
+ raise type
+ else:
+ raise type, value, traceback
+
+
+# type of a function that steps from node to node
+ctypedef public xmlNode* (*_node_to_node_function)(xmlNode*)
+
+
+################################################################################
+# Include submodules
+
+include "proxy.pxi" # Proxy handling (element backpointers/memory/etc.)
+include "apihelpers.pxi" # Private helper functions
+include "xmlerror.pxi" # Error and log handling
+
+
+################################################################################
+# Public Python API
+
+@cython.final
+@cython.freelist(8)
+cdef public class _Document [ type LxmlDocumentType, object LxmlDocument ]:
+ """Internal base class to reference a libxml document.
+
+ When instances of this class are garbage collected, the libxml
+ document is cleaned up.
+ """
+ cdef int _ns_counter
+ cdef bytes _prefix_tail
+ cdef xmlDoc* _c_doc
+ cdef _BaseParser _parser
+
+ def __dealloc__(self):
+ # if there are no more references to the document, it is safe
+ # to clean the whole thing up, as all nodes have a reference to
+ # the document
+ tree.xmlFreeDoc(self._c_doc)
+
+ @cython.final
+ cdef getroot(self):
+ # return an element proxy for the document root
+ cdef xmlNode* c_node
+ c_node = tree.xmlDocGetRootElement(self._c_doc)
+ if c_node is NULL:
+ return None
+ return _elementFactory(self, c_node)
+
+ @cython.final
+ cdef bint hasdoctype(self) noexcept:
+ # DOCTYPE gets parsed into internal subset (xmlDTD*)
+ return self._c_doc is not NULL and self._c_doc.intSubset is not NULL
+
+ @cython.final
+ cdef getdoctype(self):
+ # get doctype info: root tag, public/system ID (or None if not known)
+ cdef tree.xmlDtd* c_dtd
+ cdef xmlNode* c_root_node
+ public_id = None
+ sys_url = None
+ c_dtd = self._c_doc.intSubset
+ if c_dtd is not NULL:
+ if c_dtd.ExternalID is not NULL:
+ public_id = funicode(c_dtd.ExternalID)
+ if c_dtd.SystemID is not NULL:
+ sys_url = funicode(c_dtd.SystemID)
+ c_dtd = self._c_doc.extSubset
+ if c_dtd is not NULL:
+ if not public_id and c_dtd.ExternalID is not NULL:
+ public_id = funicode(c_dtd.ExternalID)
+ if not sys_url and c_dtd.SystemID is not NULL:
+ sys_url = funicode(c_dtd.SystemID)
+ c_root_node = tree.xmlDocGetRootElement(self._c_doc)
+ if c_root_node is NULL:
+ root_name = None
+ else:
+ root_name = funicode(c_root_node.name)
+ return root_name, public_id, sys_url
+
+ @cython.final
+ cdef getxmlinfo(self):
+ # return XML version and encoding (or None if not known)
+ cdef xmlDoc* c_doc = self._c_doc
+ if c_doc.version is NULL:
+ version = None
+ else:
+ version = funicode(c_doc.version)
+ if c_doc.encoding is NULL:
+ encoding = None
+ else:
+ encoding = funicode(c_doc.encoding)
+ return version, encoding
+
+ @cython.final
+ cdef isstandalone(self):
+ # returns True for "standalone=true",
+ # False for "standalone=false", None if not provided
+ if self._c_doc.standalone == -1:
+ return None
+ else:
+ return (self._c_doc.standalone == 1)
+
+ @cython.final
+ cdef bytes buildNewPrefix(self):
+ # get a new unique prefix ("nsX") for this document
+ cdef bytes ns
+ if self._ns_counter < len(_PREFIX_CACHE):
+ ns = _PREFIX_CACHE[self._ns_counter]
+ else:
+ ns = python.PyBytes_FromFormat("ns%d", self._ns_counter)
+ if self._prefix_tail is not None:
+ ns += self._prefix_tail
+ self._ns_counter += 1
+ if self._ns_counter < 0:
+ # overflow!
+ self._ns_counter = 0
+ if self._prefix_tail is None:
+ self._prefix_tail = b"A"
+ else:
+ self._prefix_tail += b"A"
+ return ns
+
+ @cython.final
+ cdef xmlNs* _findOrBuildNodeNs(self, xmlNode* c_node,
+ const_xmlChar* c_href, const_xmlChar* c_prefix,
+ bint is_attribute) except NULL:
+ """Get or create namespace structure for a node. Reuses the prefix if
+ possible.
+ """
+ cdef xmlNs* c_ns
+ cdef xmlNs* c_doc_ns
+ cdef python.PyObject* dict_result
+ if c_node.type != tree.XML_ELEMENT_NODE:
+ assert c_node.type == tree.XML_ELEMENT_NODE, \
+ "invalid node type %d, expected %d" % (
+ c_node.type, tree.XML_ELEMENT_NODE)
+ # look for existing ns declaration
+ c_ns = _searchNsByHref(c_node, c_href, is_attribute)
+ if c_ns is not NULL:
+ if is_attribute and c_ns.prefix is NULL:
+ # do not put namespaced attributes into the default
+ # namespace as this would break serialisation
+ pass
+ else:
+ return c_ns
+
+ # none found => determine a suitable new prefix
+ if c_prefix is NULL:
+ dict_result = python.PyDict_GetItem(
+ _DEFAULT_NAMESPACE_PREFIXES, c_href)
+ if dict_result is not NULL:
+ prefix = dict_result
+ else:
+ prefix = self.buildNewPrefix()
+ c_prefix = _xcstr(prefix)
+
+ # make sure the prefix is not in use already
+ while tree.xmlSearchNs(self._c_doc, c_node, c_prefix) is not NULL:
+ prefix = self.buildNewPrefix()
+ c_prefix = _xcstr(prefix)
+
+ # declare the namespace and return it
+ c_ns = tree.xmlNewNs(c_node, c_href, c_prefix)
+ if c_ns is NULL:
+ raise MemoryError()
+ return c_ns
+
+ @cython.final
+ cdef int _setNodeNs(self, xmlNode* c_node, const_xmlChar* c_href) except -1:
+ "Lookup namespace structure and set it for the node."
+ c_ns = self._findOrBuildNodeNs(c_node, c_href, NULL, 0)
+ tree.xmlSetNs(c_node, c_ns)
+
+
+cdef tuple __initPrefixCache():
+ cdef int i
+ return tuple([ python.PyBytes_FromFormat("ns%d", i)
+ for i in range(26) ])
+
+cdef tuple _PREFIX_CACHE = __initPrefixCache()
+
+
+cdef _Document _documentFactory(xmlDoc* c_doc, _BaseParser parser):
+ cdef _Document result
+ result = _Document.__new__(_Document)
+ result._c_doc = c_doc
+ result._ns_counter = 0
+ result._prefix_tail = None
+ if parser is None:
+ parser = __GLOBAL_PARSER_CONTEXT.getDefaultParser()
+ result._parser = parser
+ return result
+
+
+cdef object _find_invalid_public_id_characters = re.compile(
+ ur"[^\x20\x0D\x0Aa-zA-Z0-9'()+,./:=?;!*#@$_%-]+").search
+
+
+cdef class DocInfo:
+ "Document information provided by parser and DTD."
+ cdef _Document _doc
+ def __cinit__(self, tree):
+ "Create a DocInfo object for an ElementTree object or root Element."
+ self._doc = _documentOrRaise(tree)
+ root_name, public_id, system_url = self._doc.getdoctype()
+ if not root_name and (public_id or system_url):
+ raise ValueError, "Could not find root node"
+
+ @property
+ def root_name(self):
+ """Returns the name of the root node as defined by the DOCTYPE."""
+ root_name, public_id, system_url = self._doc.getdoctype()
+ return root_name
+
+ @cython.final
+ cdef tree.xmlDtd* _get_c_dtd(self):
+ """"Return the DTD. Create it if it does not yet exist."""
+ cdef xmlDoc* c_doc = self._doc._c_doc
+ cdef xmlNode* c_root_node
+ cdef const_xmlChar* c_name
+
+ if c_doc.intSubset:
+ return c_doc.intSubset
+
+ c_root_node = tree.xmlDocGetRootElement(c_doc)
+ c_name = c_root_node.name if c_root_node else NULL
+ return tree.xmlCreateIntSubset(c_doc, c_name, NULL, NULL)
+
+ def clear(self):
+ """Removes DOCTYPE and internal subset from the document."""
+ cdef xmlDoc* c_doc = self._doc._c_doc
+ cdef tree.xmlNode* c_dtd = c_doc.intSubset
+ if c_dtd is NULL:
+ return
+ tree.xmlUnlinkNode(c_dtd)
+ tree.xmlFreeNode(c_dtd)
+
+ property public_id:
+ """Public ID of the DOCTYPE.
+
+ Mutable. May be set to a valid string or None. If a DTD does not
+ exist, setting this variable (even to None) will create one.
+ """
+ def __get__(self):
+ root_name, public_id, system_url = self._doc.getdoctype()
+ return public_id
+
+ def __set__(self, value):
+ cdef xmlChar* c_value = NULL
+ if value is not None:
+ match = _find_invalid_public_id_characters(value)
+ if match:
+ raise ValueError, f'Invalid character(s) {match.group(0)!r} in public_id.'
+ value = _utf8(value)
+ c_value = tree.xmlStrdup(_xcstr(value))
+ if not c_value:
+ raise MemoryError()
+
+ c_dtd = self._get_c_dtd()
+ if not c_dtd:
+ tree.xmlFree(c_value)
+ raise MemoryError()
+ if c_dtd.ExternalID:
+ tree.xmlFree(c_dtd.ExternalID)
+ c_dtd.ExternalID = c_value
+
+ property system_url:
+ """System ID of the DOCTYPE.
+
+ Mutable. May be set to a valid string or None. If a DTD does not
+ exist, setting this variable (even to None) will create one.
+ """
+ def __get__(self):
+ root_name, public_id, system_url = self._doc.getdoctype()
+ return system_url
+
+ def __set__(self, value):
+ cdef xmlChar* c_value = NULL
+ if value is not None:
+ bvalue = _utf8(value)
+ # sys_url may be any valid unicode string that can be
+ # enclosed in single quotes or quotes.
+ if b"'" in bvalue and b'"' in bvalue:
+ raise ValueError(
+ 'System URL may not contain both single (\') and double quotes (").')
+ c_value = tree.xmlStrdup(_xcstr(bvalue))
+ if not c_value:
+ raise MemoryError()
+
+ c_dtd = self._get_c_dtd()
+ if not c_dtd:
+ tree.xmlFree(c_value)
+ raise MemoryError()
+ if c_dtd.SystemID:
+ tree.xmlFree(c_dtd.SystemID)
+ c_dtd.SystemID = c_value
+
+ @property
+ def xml_version(self):
+ """Returns the XML version as declared by the document."""
+ xml_version, encoding = self._doc.getxmlinfo()
+ return xml_version
+
+ @property
+ def encoding(self):
+ """Returns the encoding name as declared by the document."""
+ xml_version, encoding = self._doc.getxmlinfo()
+ return encoding
+
+ @property
+ def standalone(self):
+ """Returns the standalone flag as declared by the document. The possible
+ values are True (``standalone='yes'``), False
+ (``standalone='no'`` or flag not provided in the declaration),
+ and None (unknown or no declaration found). Note that a
+ normal truth test on this value will always tell if the
+ ``standalone`` flag was set to ``'yes'`` or not.
+ """
+ return self._doc.isstandalone()
+
+ property URL:
+ "The source URL of the document (or None if unknown)."
+ def __get__(self):
+ if self._doc._c_doc.URL is NULL:
+ return None
+ return _decodeFilename(self._doc._c_doc.URL)
+ def __set__(self, url):
+ url = _encodeFilename(url)
+ c_oldurl = self._doc._c_doc.URL
+ if url is None:
+ self._doc._c_doc.URL = NULL
+ else:
+ self._doc._c_doc.URL = tree.xmlStrdup(_xcstr(url))
+ if c_oldurl is not NULL:
+ tree.xmlFree(c_oldurl)
+
+ @property
+ def doctype(self):
+ """Returns a DOCTYPE declaration string for the document."""
+ root_name, public_id, system_url = self._doc.getdoctype()
+ if system_url:
+ # If '"' in system_url, we must escape it with single
+ # quotes, otherwise escape with double quotes. If url
+ # contains both a single quote and a double quote, XML
+ # standard is being violated.
+ if '"' in system_url:
+ quoted_system_url = f"'{system_url}'"
+ else:
+ quoted_system_url = f'"{system_url}"'
+ if public_id:
+ if system_url:
+ return f''
+ else:
+ return f''
+ elif system_url:
+ return f''
+ elif self._doc.hasdoctype():
+ return f''
+ else:
+ return ''
+
+ @property
+ def internalDTD(self):
+ """Returns a DTD validator based on the internal subset of the document."""
+ return _dtdFactory(self._doc._c_doc.intSubset)
+
+ @property
+ def externalDTD(self):
+ """Returns a DTD validator based on the external subset of the document."""
+ return _dtdFactory(self._doc._c_doc.extSubset)
+
+
+@cython.no_gc_clear
+cdef public class _Element [ type LxmlElementType, object LxmlElement ]:
+ """Element class.
+
+ References a document object and a libxml node.
+
+ By pointing to a Document instance, a reference is kept to
+ _Document as long as there is some pointer to a node in it.
+ """
+ cdef _Document _doc
+ cdef xmlNode* _c_node
+ cdef object _tag
+
+ def _init(self):
+ """_init(self)
+
+ Called after object initialisation. Custom subclasses may override
+ this if they recursively call _init() in the superclasses.
+ """
+
+ @cython.linetrace(False)
+ @cython.profile(False)
+ def __dealloc__(self):
+ #print("trying to free node:", self._c_node)
+ #displayNode(self._c_node, 0)
+ if self._c_node is not NULL:
+ _unregisterProxy(self)
+ attemptDeallocation(self._c_node)
+
+ # MANIPULATORS
+
+ def __setitem__(self, x, value):
+ """__setitem__(self, x, value)
+
+ Replaces the given subelement index or slice.
+ """
+ cdef xmlNode* c_node = NULL
+ cdef xmlNode* c_next
+ cdef xmlDoc* c_source_doc
+ cdef _Element element
+ cdef bint left_to_right
+ cdef Py_ssize_t slicelength = 0, step = 0
+ _assertValidNode(self)
+ if value is None:
+ raise ValueError, "cannot assign None"
+ if isinstance(x, slice):
+ # slice assignment
+ _findChildSlice(x, self._c_node, &c_node, &step, &slicelength)
+ if step > 0:
+ left_to_right = 1
+ else:
+ left_to_right = 0
+ step = -step
+ _replaceSlice(self, c_node, slicelength, step, left_to_right, value)
+ return
+ else:
+ # otherwise: normal item assignment
+ element = value
+ _assertValidNode(element)
+ c_node = _findChild(self._c_node, x)
+ if c_node is NULL:
+ raise IndexError, "list index out of range"
+ c_source_doc = element._c_node.doc
+ c_next = element._c_node.next
+ _removeText(c_node.next)
+ tree.xmlReplaceNode(c_node, element._c_node)
+ _moveTail(c_next, element._c_node)
+ moveNodeToDocument(self._doc, c_source_doc, element._c_node)
+ if not attemptDeallocation(c_node):
+ moveNodeToDocument(self._doc, c_node.doc, c_node)
+
+ def __delitem__(self, x):
+ """__delitem__(self, x)
+
+ Deletes the given subelement or a slice.
+ """
+ cdef xmlNode* c_node = NULL
+ cdef xmlNode* c_next
+ cdef Py_ssize_t step = 0, slicelength = 0
+ _assertValidNode(self)
+ if isinstance(x, slice):
+ # slice deletion
+ if _isFullSlice(x):
+ c_node = self._c_node.children
+ if c_node is not NULL:
+ if not _isElement(c_node):
+ c_node = _nextElement(c_node)
+ while c_node is not NULL:
+ c_next = _nextElement(c_node)
+ _removeNode(self._doc, c_node)
+ c_node = c_next
+ else:
+ _findChildSlice(x, self._c_node, &c_node, &step, &slicelength)
+ _deleteSlice(self._doc, c_node, slicelength, step)
+ else:
+ # item deletion
+ c_node = _findChild(self._c_node, x)
+ if c_node is NULL:
+ raise IndexError, f"index out of range: {x}"
+ _removeNode(self._doc, c_node)
+
+ def __deepcopy__(self, memo):
+ "__deepcopy__(self, memo)"
+ return self.__copy__()
+
+ def __copy__(self):
+ "__copy__(self)"
+ cdef xmlDoc* c_doc
+ cdef xmlNode* c_node
+ cdef _Document new_doc
+ _assertValidNode(self)
+ c_doc = _copyDocRoot(self._doc._c_doc, self._c_node) # recursive
+ new_doc = _documentFactory(c_doc, self._doc._parser)
+ root = new_doc.getroot()
+ if root is not None:
+ return root
+ # Comment/PI
+ c_node = c_doc.children
+ while c_node is not NULL and c_node.type != self._c_node.type:
+ c_node = c_node.next
+ if c_node is NULL:
+ return None
+ return _elementFactory(new_doc, c_node)
+
+ def set(self, key, value):
+ """set(self, key, value)
+
+ Sets an element attribute.
+ In HTML documents (not XML or XHTML), the value None is allowed and creates
+ an attribute without value (just the attribute name).
+ """
+ _assertValidNode(self)
+ _setAttributeValue(self, key, value)
+
+ def append(self, _Element element not None):
+ """append(self, element)
+
+ Adds a subelement to the end of this element.
+ """
+ _assertValidNode(self)
+ _assertValidNode(element)
+ _appendChild(self, element)
+
+ def addnext(self, _Element element not None):
+ """addnext(self, element)
+
+ Adds the element as a following sibling directly after this
+ element.
+
+ This is normally used to set a processing instruction or comment after
+ the root node of a document. Note that tail text is automatically
+ discarded when adding at the root level.
+ """
+ _assertValidNode(self)
+ _assertValidNode(element)
+ if self._c_node.parent != NULL and not _isElement(self._c_node.parent):
+ if element._c_node.type not in (tree.XML_PI_NODE, tree.XML_COMMENT_NODE):
+ raise TypeError, "Only processing instructions and comments can be siblings of the root element"
+ element.tail = None
+ _appendSibling(self, element)
+
+ def addprevious(self, _Element element not None):
+ """addprevious(self, element)
+
+ Adds the element as a preceding sibling directly before this
+ element.
+
+ This is normally used to set a processing instruction or comment
+ before the root node of a document. Note that tail text is
+ automatically discarded when adding at the root level.
+ """
+ _assertValidNode(self)
+ _assertValidNode(element)
+ if self._c_node.parent != NULL and not _isElement(self._c_node.parent):
+ if element._c_node.type != tree.XML_PI_NODE:
+ if element._c_node.type != tree.XML_COMMENT_NODE:
+ raise TypeError, "Only processing instructions and comments can be siblings of the root element"
+ element.tail = None
+ _prependSibling(self, element)
+
+ def extend(self, elements):
+ """extend(self, elements)
+
+ Extends the current children by the elements in the iterable.
+ """
+ cdef _Element element
+ _assertValidNode(self)
+ for element in elements:
+ if element is None:
+ raise TypeError, "Node must not be None"
+ _assertValidNode(element)
+ _appendChild(self, element)
+
+ def clear(self, bint keep_tail=False):
+ """clear(self, keep_tail=False)
+
+ Resets an element. This function removes all subelements, clears
+ all attributes and sets the text and tail properties to None.
+
+ Pass ``keep_tail=True`` to leave the tail text untouched.
+ """
+ cdef xmlAttr* c_attr
+ cdef xmlAttr* c_attr_next
+ cdef xmlNode* c_node
+ cdef xmlNode* c_node_next
+ _assertValidNode(self)
+ c_node = self._c_node
+ # remove self.text and self.tail
+ _removeText(c_node.children)
+ if not keep_tail:
+ _removeText(c_node.next)
+ # remove all attributes
+ c_attr = c_node.properties
+ if c_attr:
+ c_node.properties = NULL
+ tree.xmlFreePropList(c_attr)
+ # remove all subelements
+ c_node = c_node.children
+ if c_node and not _isElement(c_node):
+ c_node = _nextElement(c_node)
+ while c_node is not NULL:
+ c_node_next = _nextElement(c_node)
+ _removeNode(self._doc, c_node)
+ c_node = c_node_next
+
+ def insert(self, index: int, _Element element not None):
+ """insert(self, index, element)
+
+ Inserts a subelement at the given position in this element
+ """
+ cdef xmlNode* c_node
+ cdef xmlNode* c_next
+ cdef xmlDoc* c_source_doc
+ _assertValidNode(self)
+ _assertValidNode(element)
+ c_node = _findChild(self._c_node, index)
+ if c_node is NULL:
+ _appendChild(self, element)
+ return
+ # prevent cycles
+ if _isAncestorOrSame(element._c_node, self._c_node):
+ raise ValueError("cannot append parent to itself")
+ c_source_doc = element._c_node.doc
+ c_next = element._c_node.next
+ tree.xmlAddPrevSibling(c_node, element._c_node)
+ _moveTail(c_next, element._c_node)
+ moveNodeToDocument(self._doc, c_source_doc, element._c_node)
+
+ def remove(self, _Element element not None):
+ """remove(self, element)
+
+ Removes a matching subelement. Unlike the find methods, this
+ method compares elements based on identity, not on tag value
+ or contents.
+ """
+ cdef xmlNode* c_node
+ cdef xmlNode* c_next
+ _assertValidNode(self)
+ _assertValidNode(element)
+ c_node = element._c_node
+ if c_node.parent is not self._c_node:
+ raise ValueError, "Element is not a child of this node."
+ c_next = element._c_node.next
+ tree.xmlUnlinkNode(c_node)
+ _moveTail(c_next, c_node)
+ # fix namespace declarations
+ moveNodeToDocument(self._doc, c_node.doc, c_node)
+
+ def replace(self, _Element old_element not None,
+ _Element new_element not None):
+ """replace(self, old_element, new_element)
+
+ Replaces a subelement with the element passed as second argument.
+ """
+ cdef xmlNode* c_old_node
+ cdef xmlNode* c_old_next
+ cdef xmlNode* c_new_node
+ cdef xmlNode* c_new_next
+ cdef xmlDoc* c_source_doc
+ _assertValidNode(self)
+ _assertValidNode(old_element)
+ _assertValidNode(new_element)
+ c_old_node = old_element._c_node
+ if c_old_node.parent is not self._c_node:
+ raise ValueError, "Element is not a child of this node."
+ c_new_node = new_element._c_node
+ # prevent cycles
+ if _isAncestorOrSame(c_new_node, self._c_node):
+ raise ValueError("cannot append parent to itself")
+ # replace node
+ c_old_next = c_old_node.next
+ c_new_next = c_new_node.next
+ c_source_doc = c_new_node.doc
+ tree.xmlReplaceNode(c_old_node, c_new_node)
+ _moveTail(c_new_next, c_new_node)
+ _moveTail(c_old_next, c_old_node)
+ moveNodeToDocument(self._doc, c_source_doc, c_new_node)
+ # fix namespace declarations
+ moveNodeToDocument(self._doc, c_old_node.doc, c_old_node)
+
+ # PROPERTIES
+ property tag:
+ """Element tag
+ """
+ def __get__(self):
+ if self._tag is not None:
+ return self._tag
+ _assertValidNode(self)
+ self._tag = _namespacedName(self._c_node)
+ return self._tag
+
+ def __set__(self, value):
+ cdef _BaseParser parser
+ _assertValidNode(self)
+ ns, name = _getNsTag(value)
+ parser = self._doc._parser
+ if parser is not None and parser._for_html:
+ _htmlTagValidOrRaise(name)
+ else:
+ _tagValidOrRaise(name)
+ self._tag = value
+ tree.xmlNodeSetName(self._c_node, _xcstr(name))
+ if ns is None:
+ self._c_node.ns = NULL
+ else:
+ self._doc._setNodeNs(self._c_node, _xcstr(ns))
+
+ @property
+ def attrib(self):
+ """Element attribute dictionary. Where possible, use get(), set(),
+ keys(), values() and items() to access element attributes.
+ """
+ return _Attrib.__new__(_Attrib, self)
+
+ property text:
+ """Text before the first subelement. This is either a string or
+ the value None, if there was no text.
+ """
+ def __get__(self):
+ _assertValidNode(self)
+ return _collectText(self._c_node.children)
+
+ def __set__(self, value):
+ _assertValidNode(self)
+ if isinstance(value, QName):
+ value = _resolveQNameText(self, value).decode('utf8')
+ _setNodeText(self._c_node, value)
+
+ # using 'del el.text' is the wrong thing to do
+ #def __del__(self):
+ # _setNodeText(self._c_node, None)
+
+ property tail:
+ """Text after this element's end tag, but before the next sibling
+ element's start tag. This is either a string or the value None, if
+ there was no text.
+ """
+ def __get__(self):
+ _assertValidNode(self)
+ return _collectText(self._c_node.next)
+
+ def __set__(self, value):
+ _assertValidNode(self)
+ _setTailText(self._c_node, value)
+
+ # using 'del el.tail' is the wrong thing to do
+ #def __del__(self):
+ # _setTailText(self._c_node, None)
+
+ # not in ElementTree, read-only
+ @property
+ def prefix(self):
+ """Namespace prefix or None.
+ """
+ if self._c_node.ns is not NULL:
+ if self._c_node.ns.prefix is not NULL:
+ return funicode(self._c_node.ns.prefix)
+ return None
+
+ # not in ElementTree, read-only
+ property sourceline:
+ """Original line number as found by the parser or None if unknown.
+ """
+ def __get__(self):
+ cdef long line
+ _assertValidNode(self)
+ line = tree.xmlGetLineNo(self._c_node)
+ return line if line > 0 else None
+
+ def __set__(self, line):
+ _assertValidNode(self)
+ if line <= 0:
+ self._c_node.line = 0
+ else:
+ self._c_node.line = line
+
+ # not in ElementTree, read-only
+ @property
+ def nsmap(self):
+ """Namespace prefix->URI mapping known in the context of this
+ Element. This includes all namespace declarations of the
+ parents.
+
+ Note that changing the returned dict has no effect on the Element.
+ """
+ _assertValidNode(self)
+ return _build_nsmap(self._c_node)
+
+ # not in ElementTree, read-only
+ property base:
+ """The base URI of the Element (xml:base or HTML base URL).
+ None if the base URI is unknown.
+
+ Note that the value depends on the URL of the document that
+ holds the Element if there is no xml:base attribute on the
+ Element or its ancestors.
+
+ Setting this property will set an xml:base attribute on the
+ Element, regardless of the document type (XML or HTML).
+ """
+ def __get__(self):
+ _assertValidNode(self)
+ c_base = tree.xmlNodeGetBase(self._doc._c_doc, self._c_node)
+ if c_base is NULL:
+ if self._doc._c_doc.URL is NULL:
+ return None
+ return _decodeFilename(self._doc._c_doc.URL)
+ try:
+ base = _decodeFilename(c_base)
+ finally:
+ tree.xmlFree(c_base)
+ return base
+
+ def __set__(self, url):
+ _assertValidNode(self)
+ if url is None:
+ c_base = NULL
+ else:
+ url = _encodeFilename(url)
+ c_base = _xcstr(url)
+ tree.xmlNodeSetBase(self._c_node, c_base)
+
+ # ACCESSORS
+ def __repr__(self):
+ "__repr__(self)"
+ return "" % (self.tag, id(self))
+
+ def __getitem__(self, x):
+ """Returns the subelement at the given position or the requested
+ slice.
+ """
+ cdef xmlNode* c_node = NULL
+ cdef Py_ssize_t step = 0, slicelength = 0
+ cdef Py_ssize_t c, i
+ cdef _node_to_node_function next_element
+ cdef list result
+ _assertValidNode(self)
+ if isinstance(x, slice):
+ # slicing
+ if _isFullSlice(x):
+ return _collectChildren(self)
+ _findChildSlice(x, self._c_node, &c_node, &step, &slicelength)
+ if c_node is NULL:
+ return []
+ if step > 0:
+ next_element = _nextElement
+ else:
+ step = -step
+ next_element = _previousElement
+ result = []
+ c = 0
+ while c_node is not NULL and c < slicelength:
+ result.append(_elementFactory(self._doc, c_node))
+ c += 1
+ for i in range(step):
+ c_node = next_element(c_node)
+ if c_node is NULL:
+ break
+ return result
+ else:
+ # indexing
+ c_node = _findChild(self._c_node, x)
+ if c_node is NULL:
+ raise IndexError, "list index out of range"
+ return _elementFactory(self._doc, c_node)
+
+ def __len__(self):
+ """__len__(self)
+
+ Returns the number of subelements.
+ """
+ _assertValidNode(self)
+ return _countElements(self._c_node.children)
+
+ def __bool__(self):
+ """__bool__(self)"""
+ import warnings
+ warnings.warn(
+ "Truth-testing of elements was a source of confusion and will always "
+ "return True in future versions. "
+ "Use specific 'len(elem)' or 'elem is not None' test instead.",
+ FutureWarning
+ )
+ # emulate old behaviour
+ _assertValidNode(self)
+ return _hasChild(self._c_node)
+
+ def __contains__(self, element):
+ "__contains__(self, element)"
+ cdef xmlNode* c_node
+ _assertValidNode(self)
+ if not isinstance(element, _Element):
+ return 0
+ c_node = (<_Element>element)._c_node
+ return c_node is not NULL and c_node.parent is self._c_node
+
+ def __iter__(self):
+ "__iter__(self)"
+ return ElementChildIterator(self)
+
+ def __reversed__(self):
+ "__reversed__(self)"
+ return ElementChildIterator(self, reversed=True)
+
+ def index(self, child: _Element, start: int = None, stop: int = None):
+ """index(self, child, start=None, stop=None)
+
+ Find the position of the child within the parent.
+
+ This method is not part of the original ElementTree API.
+ """
+ cdef Py_ssize_t k, l
+ cdef Py_ssize_t c_start, c_stop
+ cdef xmlNode* c_child
+ cdef xmlNode* c_start_node
+ _assertValidNode(self)
+ _assertValidNode(child)
+ c_child = child._c_node
+ if c_child.parent is not self._c_node:
+ raise ValueError, "Element is not a child of this node."
+
+ # handle the unbounded search straight away (normal case)
+ if stop is None and (start is None or start == 0):
+ k = 0
+ c_child = c_child.prev
+ while c_child is not NULL:
+ if _isElement(c_child):
+ k += 1
+ c_child = c_child.prev
+ return k
+
+ # check indices
+ if start is None:
+ c_start = 0
+ else:
+ c_start = start
+ if stop is None:
+ c_stop = 0
+ else:
+ c_stop = stop
+ if c_stop == 0 or \
+ c_start >= c_stop and (c_stop > 0 or c_start < 0):
+ raise ValueError, "list.index(x): x not in slice"
+
+ # for negative slice indices, check slice before searching index
+ if c_start < 0 or c_stop < 0:
+ # start from right, at most up to leftmost(c_start, c_stop)
+ if c_start < c_stop:
+ k = -c_start
+ else:
+ k = -c_stop
+ c_start_node = self._c_node.last
+ l = 1
+ while c_start_node != c_child and l < k:
+ if _isElement(c_start_node):
+ l += 1
+ c_start_node = c_start_node.prev
+ if c_start_node == c_child:
+ # found! before slice end?
+ if c_stop < 0 and l <= -c_stop:
+ raise ValueError, "list.index(x): x not in slice"
+ elif c_start < 0:
+ raise ValueError, "list.index(x): x not in slice"
+
+ # now determine the index backwards from child
+ c_child = c_child.prev
+ k = 0
+ if c_stop > 0:
+ # we can optimize: stop after c_stop elements if not found
+ while c_child != NULL and k < c_stop:
+ if _isElement(c_child):
+ k += 1
+ c_child = c_child.prev
+ if k < c_stop:
+ return k
+ else:
+ # traverse all
+ while c_child != NULL:
+ if _isElement(c_child):
+ k = k + 1
+ c_child = c_child.prev
+ if c_start > 0:
+ if k >= c_start:
+ return k
+ else:
+ return k
+ if c_start != 0 or c_stop != 0:
+ raise ValueError, "list.index(x): x not in slice"
+ else:
+ raise ValueError, "list.index(x): x not in list"
+
+ def get(self, key, default=None):
+ """get(self, key, default=None)
+
+ Gets an element attribute.
+ """
+ _assertValidNode(self)
+ return _getAttributeValue(self, key, default)
+
+ def keys(self):
+ """keys(self)
+
+ Gets a list of attribute names. The names are returned in an
+ arbitrary order (just like for an ordinary Python dictionary).
+ """
+ _assertValidNode(self)
+ return _collectAttributes(self._c_node, 1)
+
+ def values(self):
+ """values(self)
+
+ Gets element attribute values as a sequence of strings. The
+ attributes are returned in an arbitrary order.
+ """
+ _assertValidNode(self)
+ return _collectAttributes(self._c_node, 2)
+
+ def items(self):
+ """items(self)
+
+ Gets element attributes, as a sequence. The attributes are returned in
+ an arbitrary order.
+ """
+ _assertValidNode(self)
+ return _collectAttributes(self._c_node, 3)
+
+ def getchildren(self):
+ """getchildren(self)
+
+ Returns all direct children. The elements are returned in document
+ order.
+
+ :deprecated: Note that this method has been deprecated as of
+ ElementTree 1.3 and lxml 2.0. New code should use
+ ``list(element)`` or simply iterate over elements.
+ """
+ _assertValidNode(self)
+ return _collectChildren(self)
+
+ def getparent(self):
+ """getparent(self)
+
+ Returns the parent of this element or None for the root element.
+ """
+ cdef xmlNode* c_node
+ #_assertValidNode(self) # not needed
+ c_node = _parentElement(self._c_node)
+ if c_node is NULL:
+ return None
+ return _elementFactory(self._doc, c_node)
+
+ def getnext(self):
+ """getnext(self)
+
+ Returns the following sibling of this element or None.
+ """
+ cdef xmlNode* c_node
+ #_assertValidNode(self) # not needed
+ c_node = _nextElement(self._c_node)
+ if c_node is NULL:
+ return None
+ return _elementFactory(self._doc, c_node)
+
+ def getprevious(self):
+ """getprevious(self)
+
+ Returns the preceding sibling of this element or None.
+ """
+ cdef xmlNode* c_node
+ #_assertValidNode(self) # not needed
+ c_node = _previousElement(self._c_node)
+ if c_node is NULL:
+ return None
+ return _elementFactory(self._doc, c_node)
+
+ def itersiblings(self, tag=None, *tags, preceding=False):
+ """itersiblings(self, tag=None, *tags, preceding=False)
+
+ Iterate over the following or preceding siblings of this element.
+
+ The direction is determined by the 'preceding' keyword which
+ defaults to False, i.e. forward iteration over the following
+ siblings. When True, the iterator yields the preceding
+ siblings in reverse document order, i.e. starting right before
+ the current element and going backwards.
+
+ Can be restricted to find only elements with specific tags,
+ see `iter`.
+ """
+ if preceding:
+ if self._c_node and not self._c_node.prev:
+ return ITER_EMPTY
+ elif self._c_node and not self._c_node.next:
+ return ITER_EMPTY
+ if tag is not None:
+ tags += (tag,)
+ return SiblingsIterator(self, tags, preceding=preceding)
+
+ def iterancestors(self, tag=None, *tags):
+ """iterancestors(self, tag=None, *tags)
+
+ Iterate over the ancestors of this element (from parent to parent).
+
+ Can be restricted to find only elements with specific tags,
+ see `iter`.
+ """
+ if self._c_node and not self._c_node.parent:
+ return ITER_EMPTY
+ if tag is not None:
+ tags += (tag,)
+ return AncestorsIterator(self, tags)
+
+ def iterdescendants(self, tag=None, *tags):
+ """iterdescendants(self, tag=None, *tags)
+
+ Iterate over the descendants of this element in document order.
+
+ As opposed to ``el.iter()``, this iterator does not yield the element
+ itself. The returned elements can be restricted to find only elements
+ with specific tags, see `iter`.
+ """
+ if self._c_node and not self._c_node.children:
+ return ITER_EMPTY
+ if tag is not None:
+ tags += (tag,)
+ return ElementDepthFirstIterator(self, tags, inclusive=False)
+
+ def iterchildren(self, tag=None, *tags, reversed=False):
+ """iterchildren(self, tag=None, *tags, reversed=False)
+
+ Iterate over the children of this element.
+
+ As opposed to using normal iteration on this element, the returned
+ elements can be reversed with the 'reversed' keyword and restricted
+ to find only elements with specific tags, see `iter`.
+ """
+ if self._c_node and not self._c_node.children:
+ return ITER_EMPTY
+ if tag is not None:
+ tags += (tag,)
+ return ElementChildIterator(self, tags, reversed=reversed)
+
+ def getroottree(self):
+ """getroottree(self)
+
+ Return an ElementTree for the root node of the document that
+ contains this element.
+
+ This is the same as following element.getparent() up the tree until it
+ returns None (for the root element) and then build an ElementTree for
+ the last parent that was returned."""
+ _assertValidDoc(self._doc)
+ return _elementTreeFactory(self._doc, None)
+
+ def getiterator(self, tag=None, *tags):
+ """getiterator(self, tag=None, *tags)
+
+ Returns a sequence or iterator of all elements in the subtree in
+ document order (depth first pre-order), starting with this
+ element.
+
+ Can be restricted to find only elements with specific tags,
+ see `iter`.
+
+ :deprecated: Note that this method is deprecated as of
+ ElementTree 1.3 and lxml 2.0. It returns an iterator in
+ lxml, which diverges from the original ElementTree
+ behaviour. If you want an efficient iterator, use the
+ ``element.iter()`` method instead. You should only use this
+ method in new code if you require backwards compatibility
+ with older versions of lxml or ElementTree.
+ """
+ if tag is not None:
+ tags += (tag,)
+ return ElementDepthFirstIterator(self, tags)
+
+ def iter(self, tag=None, *tags):
+ """iter(self, tag=None, *tags)
+
+ Iterate over all elements in the subtree in document order (depth
+ first pre-order), starting with this element.
+
+ Can be restricted to find only elements with specific tags:
+ pass ``"{ns}localname"`` as tag. Either or both of ``ns`` and
+ ``localname`` can be ``*`` for a wildcard; ``ns`` can be empty
+ for no namespace. ``"localname"`` is equivalent to ``"{}localname"``
+ (i.e. no namespace) but ``"*"`` is ``"{*}*"`` (any or no namespace),
+ not ``"{}*"``.
+
+ You can also pass the Element, Comment, ProcessingInstruction and
+ Entity factory functions to look only for the specific element type.
+
+ Passing multiple tags (or a sequence of tags) instead of a single tag
+ will let the iterator return all elements matching any of these tags,
+ in document order.
+ """
+ if tag is not None:
+ tags += (tag,)
+ return ElementDepthFirstIterator(self, tags)
+
+ def itertext(self, tag=None, *tags, with_tail=True):
+ """itertext(self, tag=None, *tags, with_tail=True)
+
+ Iterates over the text content of a subtree.
+
+ You can pass tag names to restrict text content to specific elements,
+ see `iter`.
+
+ You can set the ``with_tail`` keyword argument to ``False`` to skip
+ over tail text.
+ """
+ if tag is not None:
+ tags += (tag,)
+ return ElementTextIterator(self, tags, with_tail=with_tail)
+
+ def makeelement(self, _tag, attrib=None, nsmap=None, **_extra):
+ """makeelement(self, _tag, attrib=None, nsmap=None, **_extra)
+
+ Creates a new element associated with the same document.
+ """
+ _assertValidDoc(self._doc)
+ return _makeElement(_tag, NULL, self._doc, None, None, None,
+ attrib, nsmap, _extra)
+
+ def find(self, path, namespaces=None):
+ """find(self, path, namespaces=None)
+
+ Finds the first matching subelement, by tag name or path.
+
+ The optional ``namespaces`` argument accepts a
+ prefix-to-namespace mapping that allows the usage of XPath
+ prefixes in the path expression.
+ """
+ if isinstance(path, QName):
+ path = (path).text
+ return _elementpath.find(self, path, namespaces, with_prefixes=not _isHtmlDocument(self))
+
+ def findtext(self, path, default=None, namespaces=None):
+ """findtext(self, path, default=None, namespaces=None)
+
+ Finds text for the first matching subelement, by tag name or path.
+
+ The optional ``namespaces`` argument accepts a
+ prefix-to-namespace mapping that allows the usage of XPath
+ prefixes in the path expression.
+ """
+ if isinstance(path, QName):
+ path = (path).text
+ return _elementpath.findtext(self, path, default, namespaces, with_prefixes=not _isHtmlDocument(self))
+
+ def findall(self, path, namespaces=None):
+ """findall(self, path, namespaces=None)
+
+ Finds all matching subelements, by tag name or path.
+
+ The optional ``namespaces`` argument accepts a
+ prefix-to-namespace mapping that allows the usage of XPath
+ prefixes in the path expression.
+ """
+ if isinstance(path, QName):
+ path = (path).text
+ return _elementpath.findall(self, path, namespaces, with_prefixes=not _isHtmlDocument(self))
+
+ def iterfind(self, path, namespaces=None):
+ """iterfind(self, path, namespaces=None)
+
+ Iterates over all matching subelements, by tag name or path.
+
+ The optional ``namespaces`` argument accepts a
+ prefix-to-namespace mapping that allows the usage of XPath
+ prefixes in the path expression.
+ """
+ if isinstance(path, QName):
+ path = (path).text
+ return _elementpath.iterfind(self, path, namespaces, with_prefixes=not _isHtmlDocument(self))
+
+ def xpath(self, _path, *, namespaces=None, extensions=None,
+ smart_strings=True, **_variables):
+ """xpath(self, _path, namespaces=None, extensions=None, smart_strings=True, **_variables)
+
+ Evaluate an xpath expression using the element as context node.
+ """
+ evaluator = XPathElementEvaluator(self, namespaces=namespaces,
+ extensions=extensions,
+ smart_strings=smart_strings)
+ return evaluator(_path, **_variables)
+
+ def cssselect(self, expr, *, translator='xml'):
+ """
+ Run the CSS expression on this element and its children,
+ returning a list of the results.
+
+ Equivalent to lxml.cssselect.CSSSelect(expr)(self) -- note
+ that pre-compiling the expression can provide a substantial
+ speedup.
+ """
+ # Do the import here to make the dependency optional.
+ from lxml.cssselect import CSSSelector
+ return CSSSelector(expr, translator=translator)(self)
+
+
+@cython.linetrace(False)
+cdef _Element _elementFactory(_Document doc, xmlNode* c_node):
+ cdef _Element result
+ result = getProxy(c_node)
+ if result is not None:
+ return result
+ if c_node is NULL:
+ return None
+
+ element_class = LOOKUP_ELEMENT_CLASS(
+ ELEMENT_CLASS_LOOKUP_STATE, doc, c_node)
+ if type(element_class) is not type:
+ if not isinstance(element_class, type):
+ raise TypeError(f"Element class is not a type, got {type(element_class)}")
+ if hasProxy(c_node):
+ # prevent re-entry race condition - we just called into Python
+ return getProxy(c_node)
+ result = element_class.__new__(element_class)
+ if hasProxy(c_node):
+ # prevent re-entry race condition - we just called into Python
+ result._c_node = NULL
+ return getProxy(c_node)
+
+ _registerProxy(result, doc, c_node)
+ if element_class is not _Element:
+ result._init()
+ return result
+
+
+@cython.internal
+cdef class __ContentOnlyElement(_Element):
+ cdef int _raiseImmutable(self) except -1:
+ raise TypeError, "this element does not have children or attributes"
+
+ def set(self, key, value):
+ "set(self, key, value)"
+ self._raiseImmutable()
+
+ def append(self, value):
+ "append(self, value)"
+ self._raiseImmutable()
+
+ def insert(self, index, value):
+ "insert(self, index, value)"
+ self._raiseImmutable()
+
+ def __setitem__(self, index, value):
+ "__setitem__(self, index, value)"
+ self._raiseImmutable()
+
+ @property
+ def attrib(self):
+ return IMMUTABLE_EMPTY_MAPPING
+
+ property text:
+ def __get__(self):
+ _assertValidNode(self)
+ return funicodeOrEmpty(self._c_node.content)
+
+ def __set__(self, value):
+ cdef tree.xmlDict* c_dict
+ _assertValidNode(self)
+ if value is None:
+ c_text = NULL
+ else:
+ value = _utf8(value)
+ c_text = _xcstr(value)
+ tree.xmlNodeSetContent(self._c_node, c_text)
+
+ # ACCESSORS
+ def __getitem__(self, x):
+ "__getitem__(self, x)"
+ if isinstance(x, slice):
+ return []
+ else:
+ raise IndexError, "list index out of range"
+
+ def __len__(self):
+ "__len__(self)"
+ return 0
+
+ def get(self, key, default=None):
+ "get(self, key, default=None)"
+ return None
+
+ def keys(self):
+ "keys(self)"
+ return []
+
+ def items(self):
+ "items(self)"
+ return []
+
+ def values(self):
+ "values(self)"
+ return []
+
+cdef class _Comment(__ContentOnlyElement):
+ @property
+ def tag(self):
+ return Comment
+
+ def __repr__(self):
+ return "" % self.text
+
+cdef class _ProcessingInstruction(__ContentOnlyElement):
+ @property
+ def tag(self):
+ return ProcessingInstruction
+
+ property target:
+ # not in ElementTree
+ def __get__(self):
+ _assertValidNode(self)
+ return funicode(self._c_node.name)
+
+ def __set__(self, value):
+ _assertValidNode(self)
+ value = _utf8(value)
+ c_text = _xcstr(value)
+ tree.xmlNodeSetName(self._c_node, c_text)
+
+ def __repr__(self):
+ text = self.text
+ if text:
+ return "%s %s?>" % (self.target, text)
+ else:
+ return "%s?>" % self.target
+
+ def get(self, key, default=None):
+ """get(self, key, default=None)
+
+ Try to parse pseudo-attributes from the text content of the
+ processing instruction, search for one with the given key as
+ name and return its associated value.
+
+ Note that this is only a convenience method for the most
+ common case that all text content is structured in
+ attribute-like name-value pairs with properly quoted values.
+ It is not guaranteed to work for all possible text content.
+ """
+ return self.attrib.get(key, default)
+
+ @property
+ def attrib(self):
+ """Returns a dict containing all pseudo-attributes that can be
+ parsed from the text content of this processing instruction.
+ Note that modifying the dict currently has no effect on the
+ XML node, although this is not guaranteed to stay this way.
+ """
+ return { attr : (value1 or value2)
+ for attr, value1, value2 in _FIND_PI_ATTRIBUTES(' ' + self.text) }
+
+cdef object _FIND_PI_ATTRIBUTES = re.compile(r'\s+(\w+)\s*=\s*(?:\'([^\']*)\'|"([^"]*)")', re.U).findall
+
+cdef class _Entity(__ContentOnlyElement):
+ @property
+ def tag(self):
+ return Entity
+
+ property name:
+ # not in ElementTree
+ def __get__(self):
+ _assertValidNode(self)
+ return funicode(self._c_node.name)
+
+ def __set__(self, value):
+ _assertValidNode(self)
+ value_utf = _utf8(value)
+ if b'&' in value_utf or b';' in value_utf:
+ raise ValueError, f"Invalid entity name '{value}'"
+ tree.xmlNodeSetName(self._c_node, _xcstr(value_utf))
+
+ @property
+ def text(self):
+ # FIXME: should this be None or '&[VALUE];' or the resolved
+ # entity value ?
+ _assertValidNode(self)
+ return f'&{funicode(self._c_node.name)};'
+
+ def __repr__(self):
+ return "&%s;" % self.name
+
+
+cdef class QName:
+ """QName(text_or_uri_or_element, tag=None)
+
+ QName wrapper for qualified XML names.
+
+ Pass a tag name by itself or a namespace URI and a tag name to
+ create a qualified name. Alternatively, pass an Element to
+ extract its tag name. ``None`` as first argument is ignored in
+ order to allow for generic 2-argument usage.
+
+ The ``text`` property holds the qualified name in
+ ``{namespace}tagname`` notation. The ``namespace`` and
+ ``localname`` properties hold the respective parts of the tag
+ name.
+
+ You can pass QName objects wherever a tag name is expected. Also,
+ setting Element text from a QName will resolve the namespace prefix
+ on assignment and set a qualified text value. This is helpful in XML
+ languages like SOAP or XML-Schema that use prefixed tag names in
+ their text content.
+ """
+ cdef readonly unicode text
+ cdef readonly unicode localname
+ cdef readonly unicode namespace
+ def __init__(self, text_or_uri_or_element, tag=None):
+ if text_or_uri_or_element is None:
+ # Allow None as no namespace.
+ text_or_uri_or_element, tag = tag, None
+ if not _isString(text_or_uri_or_element):
+ if isinstance(text_or_uri_or_element, _Element):
+ text_or_uri_or_element = (<_Element>text_or_uri_or_element).tag
+ if not _isString(text_or_uri_or_element):
+ raise ValueError, f"Invalid input tag of type {type(text_or_uri_or_element)!r}"
+ elif isinstance(text_or_uri_or_element, QName):
+ text_or_uri_or_element = (text_or_uri_or_element).text
+ elif text_or_uri_or_element is not None:
+ text_or_uri_or_element = unicode(text_or_uri_or_element)
+ else:
+ raise ValueError, f"Invalid input tag of type {type(text_or_uri_or_element)!r}"
+
+ ns_utf, tag_utf = _getNsTag(text_or_uri_or_element)
+ if tag is not None:
+ # either ('ns', 'tag') or ('{ns}oldtag', 'newtag')
+ if ns_utf is None:
+ ns_utf = tag_utf # case 1: namespace ended up as tag name
+ tag_utf = _utf8(tag)
+ _tagValidOrRaise(tag_utf)
+ self.localname = (tag_utf).decode('utf8')
+ if ns_utf is None:
+ self.namespace = None
+ self.text = self.localname
+ else:
+ self.namespace = (ns_utf).decode('utf8')
+ self.text = "{%s}%s" % (self.namespace, self.localname)
+ def __str__(self):
+ return self.text
+ def __hash__(self):
+ return hash(self.text)
+ def __richcmp__(self, other, int op):
+ try:
+ if type(other) is QName:
+ other = (other).text
+ elif not isinstance(other, unicode):
+ other = unicode(other)
+ except (ValueError, UnicodeDecodeError):
+ return NotImplemented
+ return python.PyObject_RichCompare(self.text, other, op)
+
+
+cdef public class _ElementTree [ type LxmlElementTreeType,
+ object LxmlElementTree ]:
+ cdef _Document _doc
+ cdef _Element _context_node
+
+ # Note that _doc is only used to store the original document if we do not
+ # have a _context_node. All methods should prefer self._context_node._doc
+ # to honour tree restructuring. _doc can happily be None!
+
+ @cython.final
+ cdef int _assertHasRoot(self) except -1:
+ """We have to take care here: the document may not have a root node!
+ This can happen if ElementTree() is called without any argument and
+ the caller 'forgets' to call parse() afterwards, so this is a bug in
+ the caller program.
+ """
+ assert self._context_node is not None, \
+ "ElementTree not initialized, missing root"
+ return 0
+
+ def parse(self, source, _BaseParser parser=None, *, base_url=None):
+ """parse(self, source, parser=None, base_url=None)
+
+ Updates self with the content of source and returns its root.
+ """
+ cdef _Document doc = None
+ try:
+ doc = _parseDocument(source, parser, base_url)
+ except _TargetParserResult as result_container:
+ # raises a TypeError if we don't get an _Element
+ self._context_node = result_container.result
+ else:
+ self._context_node = doc.getroot()
+ self._doc = None if self._context_node is not None else doc
+ return self._context_node
+
+ def _setroot(self, _Element root not None):
+ """_setroot(self, root)
+
+ Relocate the ElementTree to a new root node.
+ """
+ _assertValidNode(root)
+ if root._c_node.type != tree.XML_ELEMENT_NODE:
+ raise TypeError, "Only elements can be the root of an ElementTree"
+ self._context_node = root
+ self._doc = None
+
+ def getroot(self):
+ """getroot(self)
+
+ Gets the root element for this tree.
+ """
+ return self._context_node
+
+ def __copy__(self):
+ return _elementTreeFactory(self._doc, self._context_node)
+
+ def __deepcopy__(self, memo):
+ cdef _Element root
+ cdef _Document doc
+ cdef xmlDoc* c_doc
+ if self._context_node is not None:
+ root = self._context_node.__copy__()
+ assert root is not None
+ _assertValidNode(root)
+ _copyNonElementSiblings(self._context_node._c_node, root._c_node)
+ return _elementTreeFactory(None, root)
+ elif self._doc is not None:
+ _assertValidDoc(self._doc)
+ c_doc = tree.xmlCopyDoc(self._doc._c_doc, 1)
+ if c_doc is NULL:
+ raise MemoryError()
+ doc = _documentFactory(c_doc, self._doc._parser)
+ return _elementTreeFactory(doc, None)
+ else:
+ # so what ...
+ return self
+
+ # not in ElementTree
+ @property
+ def docinfo(self) -> DocInfo:
+ """Information about the document provided by parser and DTD."""
+ self._assertHasRoot()
+ return DocInfo(self._context_node._doc)
+
+ # not in ElementTree, read-only
+ @property
+ def parser(self):
+ """The parser that was used to parse the document in this ElementTree.
+ """
+ if self._context_node is not None and \
+ self._context_node._doc is not None:
+ return self._context_node._doc._parser
+ if self._doc is not None:
+ return self._doc._parser
+ return None
+
+ def write(self, file, *, encoding=None, method="xml",
+ bint pretty_print=False, xml_declaration=None, bint with_tail=True,
+ standalone=None, doctype=None, compression=0,
+ bint exclusive=False, inclusive_ns_prefixes=None,
+ bint with_comments=True, bint strip_text=False,
+ docstring=None):
+ """write(self, file, encoding=None, method="xml",
+ pretty_print=False, xml_declaration=None, with_tail=True,
+ standalone=None, doctype=None, compression=0,
+ exclusive=False, inclusive_ns_prefixes=None,
+ with_comments=True, strip_text=False)
+
+ Write the tree to a filename, file or file-like object.
+
+ Defaults to ASCII encoding and writing a declaration as needed.
+
+ The keyword argument 'method' selects the output method:
+ 'xml', 'html', 'text', 'c14n' or 'c14n2'. Default is 'xml'.
+
+ With ``method="c14n"`` (C14N version 1), the options ``exclusive``,
+ ``with_comments`` and ``inclusive_ns_prefixes`` request exclusive
+ C14N, include comments, and list the inclusive prefixes respectively.
+
+ With ``method="c14n2"`` (C14N version 2), the ``with_comments`` and
+ ``strip_text`` options control the output of comments and text space
+ according to C14N 2.0.
+
+ Passing a boolean value to the ``standalone`` option will
+ output an XML declaration with the corresponding
+ ``standalone`` flag.
+
+ The ``doctype`` option allows passing in a plain string that will
+ be serialised before the XML tree. Note that passing in non
+ well-formed content here will make the XML output non well-formed.
+ Also, an existing doctype in the document tree will not be removed
+ when serialising an ElementTree instance.
+
+ The ``compression`` option enables GZip compression level 1-9.
+
+ The ``inclusive_ns_prefixes`` should be a list of namespace strings
+ (i.e. ['xs', 'xsi']) that will be promoted to the top-level element
+ during exclusive C14N serialisation. This parameter is ignored if
+ exclusive mode=False.
+
+ If exclusive=True and no list is provided, a namespace will only be
+ rendered if it is used by the immediate parent or one of its attributes
+ and its prefix and values have not already been rendered by an ancestor
+ of the namespace node's parent element.
+ """
+ cdef bint write_declaration
+ cdef int is_standalone
+
+ self._assertHasRoot()
+ _assertValidNode(self._context_node)
+ if compression is None or compression < 0:
+ compression = 0
+
+ # C14N serialisation
+ if method in ('c14n', 'c14n2'):
+ if encoding is not None:
+ raise ValueError("Cannot specify encoding with C14N")
+ if xml_declaration:
+ raise ValueError("Cannot enable XML declaration in C14N")
+
+ if method == 'c14n':
+ _tofilelikeC14N(file, self._context_node, exclusive, with_comments,
+ compression, inclusive_ns_prefixes)
+ else: # c14n2
+ with _open_utf8_file(file, compression=compression) as f:
+ target = C14NWriterTarget(
+ f.write, with_comments=with_comments, strip_text=strip_text)
+ _tree_to_target(self, target)
+ return
+
+ if not with_comments:
+ raise ValueError("Can only discard comments in C14N serialisation")
+ # suppress decl. in default case (purely for ElementTree compatibility)
+ if xml_declaration is not None:
+ write_declaration = xml_declaration
+ if encoding is None:
+ encoding = 'ASCII'
+ else:
+ encoding = encoding.upper()
+ elif encoding is None:
+ encoding = 'ASCII'
+ write_declaration = 0
+ else:
+ encoding = encoding.upper()
+ write_declaration = encoding not in (
+ 'US-ASCII', 'ASCII', 'UTF8', 'UTF-8')
+ if standalone is None:
+ is_standalone = -1
+ elif standalone:
+ write_declaration = 1
+ is_standalone = 1
+ else:
+ write_declaration = 1
+ is_standalone = 0
+
+ if docstring is not None and doctype is None:
+ import warnings
+ warnings.warn(
+ "The 'docstring' option is deprecated. Use 'doctype' instead.",
+ DeprecationWarning)
+ doctype = docstring
+
+ _tofilelike(file, self._context_node, encoding, doctype, method,
+ write_declaration, 1, pretty_print, with_tail,
+ is_standalone, compression)
+
+ def getpath(self, _Element element not None):
+ """getpath(self, element)
+
+ Returns a structural, absolute XPath expression to find the element.
+
+ For namespaced elements, the expression uses prefixes from the
+ document, which therefore need to be provided in order to make any
+ use of the expression in XPath.
+
+ Also see the method getelementpath(self, element), which returns a
+ self-contained ElementPath expression.
+ """
+ cdef _Document doc
+ cdef _Element root
+ cdef xmlDoc* c_doc
+ _assertValidNode(element)
+ if self._context_node is not None:
+ root = self._context_node
+ doc = root._doc
+ elif self._doc is not None:
+ doc = self._doc
+ root = doc.getroot()
+ else:
+ raise ValueError, "Element is not in this tree."
+ _assertValidDoc(doc)
+ _assertValidNode(root)
+ if element._doc is not doc:
+ raise ValueError, "Element is not in this tree."
+
+ c_doc = _fakeRootDoc(doc._c_doc, root._c_node)
+ c_path = tree.xmlGetNodePath(element._c_node)
+ _destroyFakeDoc(doc._c_doc, c_doc)
+ if c_path is NULL:
+ raise MemoryError()
+ path = funicode(c_path)
+ tree.xmlFree(c_path)
+ return path
+
+ def getelementpath(self, _Element element not None):
+ """getelementpath(self, element)
+
+ Returns a structural, absolute ElementPath expression to find the
+ element. This path can be used in the .find() method to look up
+ the element, provided that the elements along the path and their
+ list of immediate children were not modified in between.
+
+ ElementPath has the advantage over an XPath expression (as returned
+ by the .getpath() method) that it does not require additional prefix
+ declarations. It is always self-contained.
+ """
+ cdef _Element root
+ cdef Py_ssize_t count
+ _assertValidNode(element)
+ if element._c_node.type != tree.XML_ELEMENT_NODE:
+ raise ValueError, "input is not an Element"
+ if self._context_node is not None:
+ root = self._context_node
+ elif self._doc is not None:
+ root = self._doc.getroot()
+ else:
+ raise ValueError, "Element is not in this tree"
+ _assertValidNode(root)
+ if element._doc is not root._doc:
+ raise ValueError, "Element is not in this tree"
+
+ path = []
+ c_element = element._c_node
+ while c_element is not root._c_node:
+ c_name = c_element.name
+ c_href = _getNs(c_element)
+ tag = _namespacedNameFromNsName(c_href, c_name)
+ if c_href is NULL:
+ c_href = b'' # no namespace (NULL is wildcard)
+ # use tag[N] if there are preceding siblings with the same tag
+ count = 0
+ c_node = c_element.prev
+ while c_node is not NULL:
+ if c_node.type == tree.XML_ELEMENT_NODE:
+ if _tagMatches(c_node, c_href, c_name):
+ count += 1
+ c_node = c_node.prev
+ if count:
+ tag = f'{tag}[{count+1}]'
+ else:
+ # use tag[1] if there are following siblings with the same tag
+ c_node = c_element.next
+ while c_node is not NULL:
+ if c_node.type == tree.XML_ELEMENT_NODE:
+ if _tagMatches(c_node, c_href, c_name):
+ tag += '[1]'
+ break
+ c_node = c_node.next
+
+ path.append(tag)
+ c_element = c_element.parent
+ if c_element is NULL or c_element.type != tree.XML_ELEMENT_NODE:
+ raise ValueError, "Element is not in this tree."
+ if not path:
+ return '.'
+ path.reverse()
+ return '/'.join(path)
+
+ def getiterator(self, tag=None, *tags):
+ """getiterator(self, *tags, tag=None)
+
+ Returns a sequence or iterator of all elements in document order
+ (depth first pre-order), starting with the root element.
+
+ Can be restricted to find only elements with specific tags,
+ see `_Element.iter`.
+
+ :deprecated: Note that this method is deprecated as of
+ ElementTree 1.3 and lxml 2.0. It returns an iterator in
+ lxml, which diverges from the original ElementTree
+ behaviour. If you want an efficient iterator, use the
+ ``tree.iter()`` method instead. You should only use this
+ method in new code if you require backwards compatibility
+ with older versions of lxml or ElementTree.
+ """
+ root = self.getroot()
+ if root is None:
+ return ITER_EMPTY
+ if tag is not None:
+ tags += (tag,)
+ return root.getiterator(*tags)
+
+ def iter(self, tag=None, *tags):
+ """iter(self, tag=None, *tags)
+
+ Creates an iterator for the root element. The iterator loops over
+ all elements in this tree, in document order. Note that siblings
+ of the root element (comments or processing instructions) are not
+ returned by the iterator.
+
+ Can be restricted to find only elements with specific tags,
+ see `_Element.iter`.
+ """
+ root = self.getroot()
+ if root is None:
+ return ITER_EMPTY
+ if tag is not None:
+ tags += (tag,)
+ return root.iter(*tags)
+
+ def find(self, path, namespaces=None):
+ """find(self, path, namespaces=None)
+
+ Finds the first toplevel element with given tag. Same as
+ ``tree.getroot().find(path)``.
+
+ The optional ``namespaces`` argument accepts a
+ prefix-to-namespace mapping that allows the usage of XPath
+ prefixes in the path expression.
+ """
+ self._assertHasRoot()
+ root = self.getroot()
+ if _isString(path):
+ if path[:1] == "/":
+ path = "." + path
+ from warnings import warn
+ warn(
+ "This search incorrectly ignores the root element, and will be "
+ "fixed in a future version. If you rely on the current "
+ f"behaviour, change it to {path!r}",
+ FutureWarning, stacklevel=1
+ )
+ return root.find(path, namespaces)
+
+ def findtext(self, path, default=None, namespaces=None):
+ """findtext(self, path, default=None, namespaces=None)
+
+ Finds the text for the first element matching the ElementPath
+ expression. Same as getroot().findtext(path)
+
+ The optional ``namespaces`` argument accepts a
+ prefix-to-namespace mapping that allows the usage of XPath
+ prefixes in the path expression.
+ """
+ self._assertHasRoot()
+ root = self.getroot()
+ if _isString(path):
+ if path[:1] == "/":
+ path = "." + path
+ from warnings import warn
+ warn(
+ "This search incorrectly ignores the root element, and will be "
+ "fixed in a future version. If you rely on the current "
+ f"behaviour, change it to {path!r}",
+ FutureWarning, stacklevel=1
+ )
+ return root.findtext(path, default, namespaces)
+
+ def findall(self, path, namespaces=None):
+ """findall(self, path, namespaces=None)
+
+ Finds all elements matching the ElementPath expression. Same as
+ getroot().findall(path).
+
+ The optional ``namespaces`` argument accepts a
+ prefix-to-namespace mapping that allows the usage of XPath
+ prefixes in the path expression.
+ """
+ self._assertHasRoot()
+ root = self.getroot()
+ if _isString(path):
+ if path[:1] == "/":
+ path = "." + path
+ from warnings import warn
+ warn(
+ "This search incorrectly ignores the root element, and will be "
+ "fixed in a future version. If you rely on the current "
+ f"behaviour, change it to {path!r}",
+ FutureWarning, stacklevel=1
+ )
+ return root.findall(path, namespaces)
+
+ def iterfind(self, path, namespaces=None):
+ """iterfind(self, path, namespaces=None)
+
+ Iterates over all elements matching the ElementPath expression.
+ Same as getroot().iterfind(path).
+
+ The optional ``namespaces`` argument accepts a
+ prefix-to-namespace mapping that allows the usage of XPath
+ prefixes in the path expression.
+ """
+ self._assertHasRoot()
+ root = self.getroot()
+ if _isString(path):
+ if path[:1] == "/":
+ path = "." + path
+ from warnings import warn
+ warn(
+ "This search incorrectly ignores the root element, and will be "
+ "fixed in a future version. If you rely on the current "
+ f"behaviour, change it to {path!r}",
+ FutureWarning, stacklevel=1
+ )
+ return root.iterfind(path, namespaces)
+
+ def xpath(self, _path, *, namespaces=None, extensions=None,
+ smart_strings=True, **_variables):
+ """xpath(self, _path, namespaces=None, extensions=None, smart_strings=True, **_variables)
+
+ XPath evaluate in context of document.
+
+ ``namespaces`` is an optional dictionary with prefix to namespace URI
+ mappings, used by XPath. ``extensions`` defines additional extension
+ functions.
+
+ Returns a list (nodeset), or bool, float or string.
+
+ In case of a list result, return Element for element nodes,
+ string for text and attribute values.
+
+ Note: if you are going to apply multiple XPath expressions
+ against the same document, it is more efficient to use
+ XPathEvaluator directly.
+ """
+ self._assertHasRoot()
+ evaluator = XPathDocumentEvaluator(self, namespaces=namespaces,
+ extensions=extensions,
+ smart_strings=smart_strings)
+ return evaluator(_path, **_variables)
+
+ def xslt(self, _xslt, extensions=None, access_control=None, **_kw):
+ """xslt(self, _xslt, extensions=None, access_control=None, **_kw)
+
+ Transform this document using other document.
+
+ xslt is a tree that should be XSLT
+ keyword parameters are XSLT transformation parameters.
+
+ Returns the transformed tree.
+
+ Note: if you are going to apply the same XSLT stylesheet against
+ multiple documents, it is more efficient to use the XSLT
+ class directly.
+ """
+ self._assertHasRoot()
+ style = XSLT(_xslt, extensions=extensions,
+ access_control=access_control)
+ return style(self, **_kw)
+
+ def relaxng(self, relaxng):
+ """relaxng(self, relaxng)
+
+ Validate this document using other document.
+
+ The relaxng argument is a tree that should contain a Relax NG schema.
+
+ Returns True or False, depending on whether validation
+ succeeded.
+
+ Note: if you are going to apply the same Relax NG schema against
+ multiple documents, it is more efficient to use the RelaxNG
+ class directly.
+ """
+ self._assertHasRoot()
+ schema = RelaxNG(relaxng)
+ return schema.validate(self)
+
+ def xmlschema(self, xmlschema):
+ """xmlschema(self, xmlschema)
+
+ Validate this document using other document.
+
+ The xmlschema argument is a tree that should contain an XML Schema.
+
+ Returns True or False, depending on whether validation
+ succeeded.
+
+ Note: If you are going to apply the same XML Schema against
+ multiple documents, it is more efficient to use the XMLSchema
+ class directly.
+ """
+ self._assertHasRoot()
+ schema = XMLSchema(xmlschema)
+ return schema.validate(self)
+
+ def xinclude(self):
+ """xinclude(self)
+
+ Process the XInclude nodes in this document and include the
+ referenced XML fragments.
+
+ There is support for loading files through the file system, HTTP and
+ FTP.
+
+ Note that XInclude does not support custom resolvers in Python space
+ due to restrictions of libxml2 <= 2.6.29.
+ """
+ self._assertHasRoot()
+ XInclude()(self._context_node)
+
+ def write_c14n(self, file, *, bint exclusive=False, bint with_comments=True,
+ compression=0, inclusive_ns_prefixes=None):
+ """write_c14n(self, file, exclusive=False, with_comments=True,
+ compression=0, inclusive_ns_prefixes=None)
+
+ C14N write of document. Always writes UTF-8.
+
+ The ``compression`` option enables GZip compression level 1-9.
+
+ The ``inclusive_ns_prefixes`` should be a list of namespace strings
+ (i.e. ['xs', 'xsi']) that will be promoted to the top-level element
+ during exclusive C14N serialisation. This parameter is ignored if
+ exclusive mode=False.
+
+ If exclusive=True and no list is provided, a namespace will only be
+ rendered if it is used by the immediate parent or one of its attributes
+ and its prefix and values have not already been rendered by an ancestor
+ of the namespace node's parent element.
+
+ NOTE: This method is deprecated as of lxml 4.4 and will be removed in a
+ future release. Use ``.write(f, method="c14n")`` instead.
+ """
+ self._assertHasRoot()
+ _assertValidNode(self._context_node)
+ if compression is None or compression < 0:
+ compression = 0
+
+ _tofilelikeC14N(file, self._context_node, exclusive, with_comments,
+ compression, inclusive_ns_prefixes)
+
+cdef _ElementTree _elementTreeFactory(_Document doc, _Element context_node):
+ return _newElementTree(doc, context_node, _ElementTree)
+
+cdef _ElementTree _newElementTree(_Document doc, _Element context_node,
+ object baseclass):
+ cdef _ElementTree result
+ result = baseclass()
+ if context_node is None and doc is not None:
+ context_node = doc.getroot()
+ if context_node is None:
+ _assertValidDoc(doc)
+ result._doc = doc
+ else:
+ _assertValidNode(context_node)
+ result._context_node = context_node
+ return result
+
+
+@cython.final
+@cython.freelist(16)
+cdef class _Attrib:
+ """A dict-like proxy for the ``Element.attrib`` property.
+ """
+ cdef _Element _element
+ def __cinit__(self, _Element element not None):
+ _assertValidNode(element)
+ self._element = element
+
+ # MANIPULATORS
+ def __setitem__(self, key, value):
+ _assertValidNode(self._element)
+ _setAttributeValue(self._element, key, value)
+
+ def __delitem__(self, key):
+ _assertValidNode(self._element)
+ _delAttribute(self._element, key)
+
+ def update(self, sequence_or_dict):
+ _assertValidNode(self._element)
+ if isinstance(sequence_or_dict, (dict, _Attrib)):
+ sequence_or_dict = sequence_or_dict.items()
+ for key, value in sequence_or_dict:
+ _setAttributeValue(self._element, key, value)
+
+ def pop(self, key, *default):
+ if len(default) > 1:
+ raise TypeError, f"pop expected at most 2 arguments, got {len(default)+1}"
+ _assertValidNode(self._element)
+ result = _getAttributeValue(self._element, key, None)
+ if result is None:
+ if not default:
+ raise KeyError, key
+ result = default[0]
+ else:
+ _delAttribute(self._element, key)
+ return result
+
+ def clear(self):
+ _assertValidNode(self._element)
+ c_attrs = self._element._c_node.properties
+ if c_attrs:
+ self._element._c_node.properties = NULL
+ tree.xmlFreePropList(c_attrs)
+
+ # ACCESSORS
+ def __repr__(self):
+ _assertValidNode(self._element)
+ return repr(dict( _collectAttributes(self._element._c_node, 3) ))
+
+ def __copy__(self):
+ _assertValidNode(self._element)
+ return dict(_collectAttributes(self._element._c_node, 3))
+
+ def __deepcopy__(self, memo):
+ _assertValidNode(self._element)
+ return dict(_collectAttributes(self._element._c_node, 3))
+
+ def __getitem__(self, key):
+ _assertValidNode(self._element)
+ result = _getAttributeValue(self._element, key, None)
+ if result is None:
+ raise KeyError, key
+ return result
+
+ def __bool__(self):
+ _assertValidNode(self._element)
+ cdef xmlAttr* c_attr = self._element._c_node.properties
+ while c_attr is not NULL:
+ if c_attr.type == tree.XML_ATTRIBUTE_NODE:
+ return 1
+ c_attr = c_attr.next
+ return 0
+
+ def __len__(self):
+ _assertValidNode(self._element)
+ cdef xmlAttr* c_attr = self._element._c_node.properties
+ cdef Py_ssize_t c = 0
+ while c_attr is not NULL:
+ if c_attr.type == tree.XML_ATTRIBUTE_NODE:
+ c += 1
+ c_attr = c_attr.next
+ return c
+
+ def get(self, key, default=None):
+ _assertValidNode(self._element)
+ return _getAttributeValue(self._element, key, default)
+
+ def keys(self):
+ _assertValidNode(self._element)
+ return _collectAttributes(self._element._c_node, 1)
+
+ def __iter__(self):
+ _assertValidNode(self._element)
+ return iter(_collectAttributes(self._element._c_node, 1))
+
+ def iterkeys(self):
+ _assertValidNode(self._element)
+ return iter(_collectAttributes(self._element._c_node, 1))
+
+ def values(self):
+ _assertValidNode(self._element)
+ return _collectAttributes(self._element._c_node, 2)
+
+ def itervalues(self):
+ _assertValidNode(self._element)
+ return iter(_collectAttributes(self._element._c_node, 2))
+
+ def items(self):
+ _assertValidNode(self._element)
+ return _collectAttributes(self._element._c_node, 3)
+
+ def iteritems(self):
+ _assertValidNode(self._element)
+ return iter(_collectAttributes(self._element._c_node, 3))
+
+ def has_key(self, key):
+ _assertValidNode(self._element)
+ return key in self
+
+ def __contains__(self, key):
+ _assertValidNode(self._element)
+ cdef xmlNode* c_node
+ ns, tag = _getNsTag(key)
+ c_node = self._element._c_node
+ c_href = NULL if ns is None else _xcstr(ns)
+ return 1 if tree.xmlHasNsProp(c_node, _xcstr(tag), c_href) else 0
+
+ def __richcmp__(self, other, int op):
+ try:
+ one = dict(self.items())
+ if not isinstance(other, dict):
+ other = dict(other)
+ except (TypeError, ValueError):
+ return NotImplemented
+ return python.PyObject_RichCompare(one, other, op)
+
+MutableMapping.register(_Attrib)
+
+
+@cython.final
+@cython.internal
+cdef class _AttribIterator:
+ """Attribute iterator - for internal use only!
+ """
+ # XML attributes must not be removed while running!
+ cdef _Element _node
+ cdef xmlAttr* _c_attr
+ cdef int _keysvalues # 1 - keys, 2 - values, 3 - items (key, value)
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ cdef xmlAttr* c_attr
+ if self._node is None:
+ raise StopIteration
+ c_attr = self._c_attr
+ while c_attr is not NULL and c_attr.type != tree.XML_ATTRIBUTE_NODE:
+ c_attr = c_attr.next
+ if c_attr is NULL:
+ self._node = None
+ raise StopIteration
+
+ self._c_attr = c_attr.next
+ if self._keysvalues == 1:
+ return _namespacedName(c_attr)
+ elif self._keysvalues == 2:
+ return _attributeValue(self._node._c_node, c_attr)
+ else:
+ return (_namespacedName(c_attr),
+ _attributeValue(self._node._c_node, c_attr))
+
+cdef object _attributeIteratorFactory(_Element element, int keysvalues):
+ cdef _AttribIterator attribs
+ if element._c_node.properties is NULL:
+ return ITER_EMPTY
+ attribs = _AttribIterator()
+ attribs._node = element
+ attribs._c_attr = element._c_node.properties
+ attribs._keysvalues = keysvalues
+ return attribs
+
+
+cdef public class _ElementTagMatcher [ object LxmlElementTagMatcher,
+ type LxmlElementTagMatcherType ]:
+ """
+ Dead but public. :)
+ """
+ cdef object _pystrings
+ cdef int _node_type
+ cdef char* _href
+ cdef char* _name
+ cdef _initTagMatch(self, tag):
+ self._href = NULL
+ self._name = NULL
+ if tag is None:
+ self._node_type = 0
+ elif tag is Comment:
+ self._node_type = tree.XML_COMMENT_NODE
+ elif tag is ProcessingInstruction:
+ self._node_type = tree.XML_PI_NODE
+ elif tag is Entity:
+ self._node_type = tree.XML_ENTITY_REF_NODE
+ elif tag is Element:
+ self._node_type = tree.XML_ELEMENT_NODE
+ else:
+ self._node_type = tree.XML_ELEMENT_NODE
+ self._pystrings = _getNsTag(tag)
+ if self._pystrings[0] is not None:
+ self._href = _cstr(self._pystrings[0])
+ self._name = _cstr(self._pystrings[1])
+ if self._name[0] == c'*' and self._name[1] == c'\0':
+ self._name = NULL
+
+cdef public class _ElementIterator(_ElementTagMatcher) [
+ object LxmlElementIterator, type LxmlElementIteratorType ]:
+ """
+ Dead but public. :)
+ """
+ # we keep Python references here to control GC
+ cdef _Element _node
+ cdef _node_to_node_function _next_element
+ def __iter__(self):
+ return self
+
+ cdef void _storeNext(self, _Element node):
+ cdef xmlNode* c_node
+ c_node = self._next_element(node._c_node)
+ while c_node is not NULL and \
+ self._node_type != 0 and \
+ (self._node_type != c_node.type or
+ not _tagMatches(c_node, self._href, self._name)):
+ c_node = self._next_element(c_node)
+ if c_node is NULL:
+ self._node = None
+ else:
+ # Python ref:
+ self._node = _elementFactory(node._doc, c_node)
+
+ def __next__(self):
+ cdef xmlNode* c_node
+ cdef _Element current_node
+ if self._node is None:
+ raise StopIteration
+ # Python ref:
+ current_node = self._node
+ self._storeNext(current_node)
+ return current_node
+
+@cython.final
+@cython.internal
+cdef class _MultiTagMatcher:
+ """
+ Match an xmlNode against a list of tags.
+ """
+ cdef list _py_tags
+ cdef qname* _cached_tags
+ cdef size_t _tag_count
+ cdef size_t _cached_size
+ cdef _Document _cached_doc
+ cdef int _node_types
+
+ def __cinit__(self, tags):
+ self._py_tags = []
+ self.initTagMatch(tags)
+
+ def __dealloc__(self):
+ self._clear()
+
+ cdef bint rejectsAll(self) noexcept:
+ return not self._tag_count and not self._node_types
+
+ cdef bint rejectsAllAttributes(self) noexcept:
+ return not self._tag_count
+
+ cdef bint matchesType(self, int node_type) noexcept:
+ if node_type == tree.XML_ELEMENT_NODE and self._tag_count:
+ return True
+ return self._node_types & (1 << node_type)
+
+ cdef void _clear(self) noexcept:
+ cdef size_t i, count
+ count = self._tag_count
+ self._tag_count = 0
+ if self._cached_tags:
+ for i in range(count):
+ cpython.ref.Py_XDECREF(self._cached_tags[i].href)
+ python.lxml_free(self._cached_tags)
+ self._cached_tags = NULL
+
+ cdef initTagMatch(self, tags):
+ self._cached_doc = None
+ del self._py_tags[:]
+ self._clear()
+ if tags is None or tags == ():
+ # no selection in tags argument => match anything
+ self._node_types = (
+ 1 << tree.XML_COMMENT_NODE |
+ 1 << tree.XML_PI_NODE |
+ 1 << tree.XML_ENTITY_REF_NODE |
+ 1 << tree.XML_ELEMENT_NODE)
+ else:
+ self._node_types = 0
+ self._storeTags(tags, set())
+
+ cdef _storeTags(self, tag, set seen):
+ if tag is Comment:
+ self._node_types |= 1 << tree.XML_COMMENT_NODE
+ elif tag is ProcessingInstruction:
+ self._node_types |= 1 << tree.XML_PI_NODE
+ elif tag is Entity:
+ self._node_types |= 1 << tree.XML_ENTITY_REF_NODE
+ elif tag is Element:
+ self._node_types |= 1 << tree.XML_ELEMENT_NODE
+ elif python._isString(tag):
+ if tag in seen:
+ return
+ seen.add(tag)
+ if tag in ('*', '{*}*'):
+ self._node_types |= 1 << tree.XML_ELEMENT_NODE
+ else:
+ href, name = _getNsTag(tag)
+ if name == b'*':
+ name = None
+ if href is None:
+ href = b'' # no namespace
+ elif href == b'*':
+ href = None # wildcard: any namespace, including none
+ self._py_tags.append((href, name))
+ elif isinstance(tag, QName):
+ self._storeTags(tag.text, seen)
+ else:
+ # support a sequence of tags
+ for item in tag:
+ self._storeTags(item, seen)
+
+ cdef inline int cacheTags(self, _Document doc, bint force_into_dict=False) except -1:
+ """
+ Look up the tag names in the doc dict to enable string pointer comparisons.
+ """
+ cdef size_t dict_size = tree.xmlDictSize(doc._c_doc.dict)
+ if doc is self._cached_doc and dict_size == self._cached_size:
+ # doc and dict didn't change => names already cached
+ return 0
+ self._tag_count = 0
+ if not self._py_tags:
+ self._cached_doc = doc
+ self._cached_size = dict_size
+ return 0
+ if not self._cached_tags:
+ self._cached_tags = python.lxml_malloc(len(self._py_tags), sizeof(qname))
+ if not self._cached_tags:
+ self._cached_doc = None
+ raise MemoryError()
+ self._tag_count = _mapTagsToQnameMatchArray(
+ doc._c_doc, self._py_tags, self._cached_tags, force_into_dict)
+ self._cached_doc = doc
+ self._cached_size = dict_size
+ return 0
+
+ cdef inline bint matches(self, xmlNode* c_node) noexcept:
+ cdef qname* c_qname
+ if self._node_types & (1 << c_node.type):
+ return True
+ elif c_node.type == tree.XML_ELEMENT_NODE:
+ for c_qname in self._cached_tags[:self._tag_count]:
+ if _tagMatchesExactly(c_node, c_qname):
+ return True
+ return False
+
+ cdef inline bint matchesNsTag(self, const_xmlChar* c_href,
+ const_xmlChar* c_name) noexcept:
+ cdef qname* c_qname
+ if self._node_types & (1 << tree.XML_ELEMENT_NODE):
+ return True
+ for c_qname in self._cached_tags[:self._tag_count]:
+ if _nsTagMatchesExactly(c_href, c_name, c_qname):
+ return True
+ return False
+
+ cdef inline bint matchesAttribute(self, xmlAttr* c_attr) noexcept:
+ """Attribute matches differ from Element matches in that they do
+ not care about node types.
+ """
+ cdef qname* c_qname
+ for c_qname in self._cached_tags[:self._tag_count]:
+ if _tagMatchesExactly(c_attr, c_qname):
+ return True
+ return False
+
+cdef class _ElementMatchIterator:
+ cdef _Element _node
+ cdef _node_to_node_function _next_element
+ cdef _MultiTagMatcher _matcher
+
+ @cython.final
+ cdef _initTagMatcher(self, tags):
+ self._matcher = _MultiTagMatcher.__new__(_MultiTagMatcher, tags)
+
+ def __iter__(self):
+ return self
+
+ @cython.final
+ cdef int _storeNext(self, _Element node) except -1:
+ self._matcher.cacheTags(node._doc)
+ c_node = self._next_element(node._c_node)
+ while c_node is not NULL and not self._matcher.matches(c_node):
+ c_node = self._next_element(c_node)
+ # store Python ref to next node to make sure it's kept alive
+ self._node = _elementFactory(node._doc, c_node) if c_node is not NULL else None
+ return 0
+
+ def __next__(self):
+ cdef _Element current_node = self._node
+ if current_node is None:
+ raise StopIteration
+ self._storeNext(current_node)
+ return current_node
+
+cdef class ElementChildIterator(_ElementMatchIterator):
+ """ElementChildIterator(self, node, tag=None, reversed=False)
+ Iterates over the children of an element.
+ """
+ def __cinit__(self, _Element node not None, tag=None, *, bint reversed=False):
+ cdef xmlNode* c_node
+ _assertValidNode(node)
+ self._initTagMatcher(tag)
+ if reversed:
+ c_node = _findChildBackwards(node._c_node, 0)
+ self._next_element = _previousElement
+ else:
+ c_node = _findChildForwards(node._c_node, 0)
+ self._next_element = _nextElement
+ self._matcher.cacheTags(node._doc)
+ while c_node is not NULL and not self._matcher.matches(c_node):
+ c_node = self._next_element(c_node)
+ # store Python ref to next node to make sure it's kept alive
+ self._node = _elementFactory(node._doc, c_node) if c_node is not NULL else None
+
+cdef class SiblingsIterator(_ElementMatchIterator):
+ """SiblingsIterator(self, node, tag=None, preceding=False)
+ Iterates over the siblings of an element.
+
+ You can pass the boolean keyword ``preceding`` to specify the direction.
+ """
+ def __cinit__(self, _Element node not None, tag=None, *, bint preceding=False):
+ _assertValidNode(node)
+ self._initTagMatcher(tag)
+ if preceding:
+ self._next_element = _previousElement
+ else:
+ self._next_element = _nextElement
+ self._storeNext(node)
+
+cdef class AncestorsIterator(_ElementMatchIterator):
+ """AncestorsIterator(self, node, tag=None)
+ Iterates over the ancestors of an element (from parent to parent).
+ """
+ def __cinit__(self, _Element node not None, tag=None):
+ _assertValidNode(node)
+ self._initTagMatcher(tag)
+ self._next_element = _parentElement
+ self._storeNext(node)
+
+cdef class ElementDepthFirstIterator:
+ """ElementDepthFirstIterator(self, node, tag=None, inclusive=True)
+ Iterates over an element and its sub-elements in document order (depth
+ first pre-order).
+
+ Note that this also includes comments, entities and processing
+ instructions. To filter them out, check if the ``tag`` property
+ of the returned element is a string (i.e. not None and not a
+ factory function), or pass the ``Element`` factory for the ``tag``
+ argument to receive only Elements.
+
+ If the optional ``tag`` argument is not None, the iterator returns only
+ the elements that match the respective name and namespace.
+
+ The optional boolean argument 'inclusive' defaults to True and can be set
+ to False to exclude the start element itself.
+
+ Note that the behaviour of this iterator is completely undefined if the
+ tree it traverses is modified during iteration.
+ """
+ # we keep Python references here to control GC
+ # keep the next Element after the one we return, and the (s)top node
+ cdef _Element _next_node
+ cdef _Element _top_node
+ cdef _MultiTagMatcher _matcher
+ def __cinit__(self, _Element node not None, tag=None, *, bint inclusive=True):
+ _assertValidNode(node)
+ self._top_node = node
+ self._next_node = node
+ self._matcher = _MultiTagMatcher.__new__(_MultiTagMatcher, tag)
+ self._matcher.cacheTags(node._doc)
+ if not inclusive or not self._matcher.matches(node._c_node):
+ # find start node (this cannot raise StopIteration, self._next_node != None)
+ next(self)
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ cdef xmlNode* c_node
+ cdef _Element current_node = self._next_node
+ if current_node is None:
+ raise StopIteration
+ c_node = current_node._c_node
+ self._matcher.cacheTags(current_node._doc)
+ if not self._matcher._tag_count:
+ # no tag name was found in the dict => not in document either
+ # try to match by node type
+ c_node = self._nextNodeAnyTag(c_node)
+ else:
+ c_node = self._nextNodeMatchTag(c_node)
+ if c_node is NULL:
+ self._next_node = None
+ else:
+ self._next_node = _elementFactory(current_node._doc, c_node)
+ return current_node
+
+ @cython.final
+ cdef xmlNode* _nextNodeAnyTag(self, xmlNode* c_node) noexcept:
+ cdef int node_types = self._matcher._node_types
+ if not node_types:
+ return NULL
+ tree.BEGIN_FOR_EACH_ELEMENT_FROM(self._top_node._c_node, c_node, 0)
+ if node_types & (1 << c_node.type):
+ return c_node
+ tree.END_FOR_EACH_ELEMENT_FROM(c_node)
+ return NULL
+
+ @cython.final
+ cdef xmlNode* _nextNodeMatchTag(self, xmlNode* c_node) noexcept:
+ tree.BEGIN_FOR_EACH_ELEMENT_FROM(self._top_node._c_node, c_node, 0)
+ if self._matcher.matches(c_node):
+ return c_node
+ tree.END_FOR_EACH_ELEMENT_FROM(c_node)
+ return NULL
+
+
+cdef class ElementTextIterator:
+ """ElementTextIterator(self, element, tag=None, with_tail=True)
+ Iterates over the text content of a subtree.
+
+ You can pass the ``tag`` keyword argument to restrict text content to a
+ specific tag name.
+
+ You can set the ``with_tail`` keyword argument to ``False`` to skip over
+ tail text (e.g. if you know that it's only whitespace from pretty-printing).
+ """
+ cdef object _events
+ cdef _Element _start_element
+ def __cinit__(self, _Element element not None, tag=None, *, bint with_tail=True):
+ _assertValidNode(element)
+ if with_tail:
+ events = ("start", "comment", "pi", "end")
+ else:
+ events = ("start",)
+ self._start_element = element
+ self._events = iterwalk(element, events=events, tag=tag)
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ cdef _Element element
+ result = None
+ while result is None:
+ event, element = next(self._events) # raises StopIteration
+ if event == "start":
+ result = element.text
+ elif element is not self._start_element:
+ result = element.tail
+ return result
+
+
+cdef xmlNode* _createElement(xmlDoc* c_doc, object name_utf) except NULL:
+ cdef xmlNode* c_node
+ c_node = tree.xmlNewDocNode(c_doc, NULL, _xcstr(name_utf), NULL)
+ return c_node
+
+cdef xmlNode* _createComment(xmlDoc* c_doc, const_xmlChar* text) noexcept:
+ cdef xmlNode* c_node
+ c_node = tree.xmlNewDocComment(c_doc, text)
+ return c_node
+
+cdef xmlNode* _createPI(xmlDoc* c_doc, const_xmlChar* target, const_xmlChar* text) noexcept:
+ cdef xmlNode* c_node
+ c_node = tree.xmlNewDocPI(c_doc, target, text)
+ return c_node
+
+cdef xmlNode* _createEntity(xmlDoc* c_doc, const_xmlChar* name) noexcept:
+ cdef xmlNode* c_node
+ c_node = tree.xmlNewReference(c_doc, name)
+ return c_node
+
+# module-level API for ElementTree
+
+from abc import ABC
+
+class Element(ABC):
+ """Element(_tag, attrib=None, nsmap=None, **_extra)
+
+ Element factory, as a class.
+
+ An instance of this class is an object implementing the
+ Element interface.
+
+ >>> element = Element("test")
+ >>> type(element)
+
+ >>> isinstance(element, Element)
+ True
+ >>> issubclass(_Element, Element)
+ True
+
+ Also look at the `_Element.makeelement()` and
+ `_BaseParser.makeelement()` methods, which provide a faster way to
+ create an Element within a specific document or parser context.
+ """
+ def __new__(cls, _tag, attrib=None, nsmap=None, **_extra):
+ return _makeElement(_tag, NULL, None, None, None, None,
+ attrib, nsmap, _extra)
+
+# Register _Element as a virtual subclass of Element
+Element.register(_Element)
+
+
+def Comment(text=None):
+ """Comment(text=None)
+
+ Comment element factory. This factory function creates a special element that will
+ be serialized as an XML comment.
+ """
+ cdef _Document doc
+ cdef xmlNode* c_node
+ cdef xmlDoc* c_doc
+
+ if text is None:
+ text = b''
+ else:
+ text = _utf8(text)
+ if b'--' in text or text.endswith(b'-'):
+ raise ValueError("Comment may not contain '--' or end with '-'")
+
+ c_doc = _newXMLDoc()
+ doc = _documentFactory(c_doc, None)
+ c_node = _createComment(c_doc, _xcstr(text))
+ tree.xmlAddChild(c_doc, c_node)
+ return _elementFactory(doc, c_node)
+
+
+def ProcessingInstruction(target, text=None):
+ """ProcessingInstruction(target, text=None)
+
+ ProcessingInstruction element factory. This factory function creates a
+ special element that will be serialized as an XML processing instruction.
+ """
+ cdef _Document doc
+ cdef xmlNode* c_node
+ cdef xmlDoc* c_doc
+
+ target = _utf8(target)
+ _tagValidOrRaise(target)
+ if target.lower() == b'xml':
+ raise ValueError, f"Invalid PI name '{target}'"
+
+ if text is None:
+ text = b''
+ else:
+ text = _utf8(text)
+ if b'?>' in text:
+ raise ValueError, "PI text must not contain '?>'"
+
+ c_doc = _newXMLDoc()
+ doc = _documentFactory(c_doc, None)
+ c_node = _createPI(c_doc, _xcstr(target), _xcstr(text))
+ tree.xmlAddChild(c_doc, c_node)
+ return _elementFactory(doc, c_node)
+
+PI = ProcessingInstruction
+
+
+cdef class CDATA:
+ """CDATA(data)
+
+ CDATA factory. This factory creates an opaque data object that
+ can be used to set Element text. The usual way to use it is::
+
+ >>> el = Element('content')
+ >>> el.text = CDATA('a string')
+
+ >>> print(el.text)
+ a string
+ >>> print(tostring(el, encoding="unicode"))
+
+ """
+ cdef bytes _utf8_data
+ def __cinit__(self, data):
+ self._utf8_data = _utf8(data)
+
+
+def Entity(name):
+ """Entity(name)
+
+ Entity factory. This factory function creates a special element
+ that will be serialized as an XML entity reference or character
+ reference. Note, however, that entities will not be automatically
+ declared in the document. A document that uses entity references
+ requires a DTD to define the entities.
+ """
+ cdef _Document doc
+ cdef xmlNode* c_node
+ cdef xmlDoc* c_doc
+ name_utf = _utf8(name)
+ c_name = _xcstr(name_utf)
+ if c_name[0] == c'#':
+ if not _characterReferenceIsValid(c_name + 1):
+ raise ValueError, f"Invalid character reference: '{name}'"
+ elif not _xmlNameIsValid(c_name):
+ raise ValueError, f"Invalid entity reference: '{name}'"
+ c_doc = _newXMLDoc()
+ doc = _documentFactory(c_doc, None)
+ c_node = _createEntity(c_doc, c_name)
+ tree.xmlAddChild(c_doc, c_node)
+ return _elementFactory(doc, c_node)
+
+
+def SubElement(_Element _parent not None, _tag,
+ attrib=None, nsmap=None, **_extra):
+ """SubElement(_parent, _tag, attrib=None, nsmap=None, **_extra)
+
+ Subelement factory. This function creates an element instance, and
+ appends it to an existing element.
+ """
+ return _makeSubElement(_parent, _tag, None, None, attrib, nsmap, _extra)
+
+from typing import Generic, TypeVar
+
+T = TypeVar("T")
+
+class ElementTree(ABC, Generic[T]):
+ def __new__(cls, _Element element=None, *, file=None, _BaseParser parser=None):
+ """ElementTree(element=None, file=None, parser=None)
+
+ ElementTree wrapper class.
+ """
+ cdef xmlNode* c_next
+ cdef xmlNode* c_node
+ cdef xmlNode* c_node_copy
+ cdef xmlDoc* c_doc
+ cdef _ElementTree etree
+ cdef _Document doc
+
+ if element is not None:
+ doc = element._doc
+ elif file is not None:
+ try:
+ doc = _parseDocument(file, parser, None)
+ except _TargetParserResult as result_container:
+ return result_container.result
+ else:
+ c_doc = _newXMLDoc()
+ doc = _documentFactory(c_doc, parser)
+
+ return _elementTreeFactory(doc, element)
+
+# Register _ElementTree as a virtual subclass of ElementTree
+ElementTree.register(_ElementTree)
+
+# Remove "ABC" and typing helpers from module dict
+del ABC, Generic, TypeVar, T
+
+def HTML(text, _BaseParser parser=None, *, base_url=None):
+ """HTML(text, parser=None, base_url=None)
+
+ Parses an HTML document from a string constant. Returns the root
+ node (or the result returned by a parser target). This function
+ can be used to embed "HTML literals" in Python code.
+
+ To override the parser with a different ``HTMLParser`` you can pass it to
+ the ``parser`` keyword argument.
+
+ The ``base_url`` keyword argument allows to set the original base URL of
+ the document to support relative Paths when looking up external entities
+ (DTD, XInclude, ...).
+ """
+ cdef _Document doc
+ if parser is None:
+ parser = __GLOBAL_PARSER_CONTEXT.getDefaultParser()
+ if not isinstance(parser, HTMLParser):
+ parser = __DEFAULT_HTML_PARSER
+ try:
+ doc = _parseMemoryDocument(text, base_url, parser)
+ return doc.getroot()
+ except _TargetParserResult as result_container:
+ return result_container.result
+
+
+def XML(text, _BaseParser parser=None, *, base_url=None):
+ """XML(text, parser=None, base_url=None)
+
+ Parses an XML document or fragment from a string constant.
+ Returns the root node (or the result returned by a parser target).
+ This function can be used to embed "XML literals" in Python code,
+ like in
+
+ >>> root = XML(" ")
+ >>> print(root.tag)
+ root
+
+ To override the parser with a different ``XMLParser`` you can pass it to
+ the ``parser`` keyword argument.
+
+ The ``base_url`` keyword argument allows to set the original base URL of
+ the document to support relative Paths when looking up external entities
+ (DTD, XInclude, ...).
+ """
+ cdef _Document doc
+ if parser is None:
+ parser = __GLOBAL_PARSER_CONTEXT.getDefaultParser()
+ if not isinstance(parser, XMLParser):
+ parser = __DEFAULT_XML_PARSER
+ try:
+ doc = _parseMemoryDocument(text, base_url, parser)
+ return doc.getroot()
+ except _TargetParserResult as result_container:
+ return result_container.result
+
+
+def fromstring(text, _BaseParser parser=None, *, base_url=None):
+ """fromstring(text, parser=None, base_url=None)
+
+ Parses an XML document or fragment from a string. Returns the
+ root node (or the result returned by a parser target).
+
+ To override the default parser with a different parser you can pass it to
+ the ``parser`` keyword argument.
+
+ The ``base_url`` keyword argument allows to set the original base URL of
+ the document to support relative Paths when looking up external entities
+ (DTD, XInclude, ...).
+ """
+ cdef _Document doc
+ try:
+ doc = _parseMemoryDocument(text, base_url, parser)
+ return doc.getroot()
+ except _TargetParserResult as result_container:
+ return result_container.result
+
+
+def fromstringlist(strings, _BaseParser parser=None):
+ """fromstringlist(strings, parser=None)
+
+ Parses an XML document from a sequence of strings. Returns the
+ root node (or the result returned by a parser target).
+
+ To override the default parser with a different parser you can pass it to
+ the ``parser`` keyword argument.
+ """
+ cdef _Document doc
+ if isinstance(strings, (bytes, unicode)):
+ raise ValueError("passing a single string into fromstringlist() is not"
+ " efficient, use fromstring() instead")
+ if parser is None:
+ parser = __GLOBAL_PARSER_CONTEXT.getDefaultParser()
+ feed = parser.feed
+ for data in strings:
+ feed(data)
+ return parser.close()
+
+
+def iselement(element):
+ """iselement(element)
+
+ Checks if an object appears to be a valid element object.
+ """
+ return isinstance(element, _Element) and (<_Element>element)._c_node is not NULL
+
+
+def indent(tree, space=" ", *, Py_ssize_t level=0):
+ """indent(tree, space=" ", level=0)
+
+ Indent an XML document by inserting newlines and indentation space
+ after elements.
+
+ *tree* is the ElementTree or Element to modify. The (root) element
+ itself will not be changed, but the tail text of all elements in its
+ subtree will be adapted.
+
+ *space* is the whitespace to insert for each indentation level, two
+ space characters by default.
+
+ *level* is the initial indentation level. Setting this to a higher
+ value than 0 can be used for indenting subtrees that are more deeply
+ nested inside of a document.
+ """
+ root = _rootNodeOrRaise(tree)
+ if level < 0:
+ raise ValueError(f"Initial indentation level must be >= 0, got {level}")
+ if _hasChild(root._c_node):
+ space = _utf8(space)
+ indent = b"\n" + level * space
+ _indent_children(root._c_node, 1, space, [indent, indent + space])
+
+
+cdef int _indent_children(xmlNode* c_node, Py_ssize_t level, bytes one_space, list indentations) except -1:
+ # Reuse indentation strings for speed.
+ if len(indentations) <= level:
+ indentations.append(indentations[-1] + one_space)
+
+ # Start a new indentation level for the first child.
+ child_indentation = indentations[level]
+ if not _hasNonWhitespaceText(c_node):
+ _setNodeText(c_node, child_indentation)
+
+ # Recursively indent all children.
+ cdef xmlNode* c_child = _findChildForwards(c_node, 0)
+ while c_child is not NULL:
+ if _hasChild(c_child):
+ _indent_children(c_child, level+1, one_space, indentations)
+ c_next_child = _nextElement(c_child)
+ if not _hasNonWhitespaceTail(c_child):
+ if c_next_child is NULL:
+ # Dedent after the last child.
+ child_indentation = indentations[level-1]
+ _setTailText(c_child, child_indentation)
+ c_child = c_next_child
+ return 0
+
+
+def dump(_Element elem not None, *, bint pretty_print=True, bint with_tail=True):
+ """dump(elem, pretty_print=True, with_tail=True)
+
+ Writes an element tree or element structure to sys.stdout. This function
+ should be used for debugging only.
+ """
+ xml = tostring(elem, pretty_print=pretty_print, with_tail=with_tail, encoding='unicode')
+ if not pretty_print:
+ xml += '\n'
+ sys.stdout.write(xml)
+
+
+def tostring(element_or_tree, *, encoding=None, method="xml",
+ xml_declaration=None, bint pretty_print=False, bint with_tail=True,
+ standalone=None, doctype=None,
+ # method='c14n'
+ bint exclusive=False, inclusive_ns_prefixes=None,
+ # method='c14n2'
+ bint with_comments=True, bint strip_text=False,
+ ):
+ """tostring(element_or_tree, encoding=None, method="xml",
+ xml_declaration=None, pretty_print=False, with_tail=True,
+ standalone=None, doctype=None,
+ exclusive=False, inclusive_ns_prefixes=None,
+ with_comments=True, strip_text=False,
+ )
+
+ Serialize an element to an encoded string representation of its XML
+ tree.
+
+ Defaults to ASCII encoding without XML declaration. This
+ behaviour can be configured with the keyword arguments 'encoding'
+ (string) and 'xml_declaration' (bool). Note that changing the
+ encoding to a non UTF-8 compatible encoding will enable a
+ declaration by default.
+
+ You can also serialise to a Unicode string without declaration by
+ passing the name ``'unicode'`` as encoding (or the ``str`` function
+ in Py3 or ``unicode`` in Py2). This changes the return value from
+ a byte string to an unencoded unicode string.
+
+ The keyword argument 'pretty_print' (bool) enables formatted XML.
+
+ The keyword argument 'method' selects the output method: 'xml',
+ 'html', plain 'text' (text content without tags), 'c14n' or 'c14n2'.
+ Default is 'xml'.
+
+ With ``method="c14n"`` (C14N version 1), the options ``exclusive``,
+ ``with_comments`` and ``inclusive_ns_prefixes`` request exclusive
+ C14N, include comments, and list the inclusive prefixes respectively.
+
+ With ``method="c14n2"`` (C14N version 2), the ``with_comments`` and
+ ``strip_text`` options control the output of comments and text space
+ according to C14N 2.0.
+
+ Passing a boolean value to the ``standalone`` option will output
+ an XML declaration with the corresponding ``standalone`` flag.
+
+ The ``doctype`` option allows passing in a plain string that will
+ be serialised before the XML tree. Note that passing in non
+ well-formed content here will make the XML output non well-formed.
+ Also, an existing doctype in the document tree will not be removed
+ when serialising an ElementTree instance.
+
+ You can prevent the tail text of the element from being serialised
+ by passing the boolean ``with_tail`` option. This has no impact
+ on the tail text of children, which will always be serialised.
+ """
+ cdef bint write_declaration
+ cdef int is_standalone
+ # C14N serialisation
+ if method in ('c14n', 'c14n2'):
+ if encoding is not None:
+ raise ValueError("Cannot specify encoding with C14N")
+ if xml_declaration:
+ raise ValueError("Cannot enable XML declaration in C14N")
+ if method == 'c14n':
+ return _tostringC14N(element_or_tree, exclusive, with_comments, inclusive_ns_prefixes)
+ else:
+ out = BytesIO()
+ target = C14NWriterTarget(
+ utf8_writer(out).write,
+ with_comments=with_comments, strip_text=strip_text)
+ _tree_to_target(element_or_tree, target)
+ return out.getvalue()
+ if not with_comments:
+ raise ValueError("Can only discard comments in C14N serialisation")
+ if strip_text:
+ raise ValueError("Can only strip text in C14N 2.0 serialisation")
+ if encoding is unicode or (encoding is not None and encoding.lower() == 'unicode'):
+ if xml_declaration:
+ raise ValueError, \
+ "Serialisation to unicode must not request an XML declaration"
+ write_declaration = 0
+ encoding = unicode
+ elif xml_declaration is None:
+ # by default, write an XML declaration only for non-standard encodings
+ write_declaration = encoding is not None and encoding.upper() not in \
+ ('ASCII', 'UTF-8', 'UTF8', 'US-ASCII')
+ else:
+ write_declaration = xml_declaration
+ if encoding is None:
+ encoding = 'ASCII'
+ if standalone is None:
+ is_standalone = -1
+ elif standalone:
+ write_declaration = 1
+ is_standalone = 1
+ else:
+ write_declaration = 1
+ is_standalone = 0
+
+ if isinstance(element_or_tree, _Element):
+ return _tostring(<_Element>element_or_tree, encoding, doctype, method,
+ write_declaration, 0, pretty_print, with_tail,
+ is_standalone)
+ elif isinstance(element_or_tree, _ElementTree):
+ return _tostring((<_ElementTree>element_or_tree)._context_node,
+ encoding, doctype, method, write_declaration, 1,
+ pretty_print, with_tail, is_standalone)
+ else:
+ raise TypeError, f"Type '{python._fqtypename(element_or_tree).decode('utf8')}' cannot be serialized."
+
+
+
+def tostringlist(element_or_tree, *args, **kwargs):
+ """tostringlist(element_or_tree, *args, **kwargs)
+
+ Serialize an element to an encoded string representation of its XML
+ tree, stored in a list of partial strings.
+
+ This is purely for ElementTree 1.3 compatibility. The result is a
+ single string wrapped in a list.
+ """
+ return [tostring(element_or_tree, *args, **kwargs)]
+
+
+def tounicode(element_or_tree, *, method="xml", bint pretty_print=False,
+ bint with_tail=True, doctype=None):
+ """tounicode(element_or_tree, method="xml", pretty_print=False,
+ with_tail=True, doctype=None)
+
+ Serialize an element to the Python unicode representation of its XML
+ tree.
+
+ :deprecated: use ``tostring(el, encoding='unicode')`` instead.
+
+ Note that the result does not carry an XML encoding declaration and is
+ therefore not necessarily suited for serialization to byte streams without
+ further treatment.
+
+ The boolean keyword argument 'pretty_print' enables formatted XML.
+
+ The keyword argument 'method' selects the output method: 'xml',
+ 'html' or plain 'text'.
+
+ You can prevent the tail text of the element from being serialised
+ by passing the boolean ``with_tail`` option. This has no impact
+ on the tail text of children, which will always be serialised.
+ """
+ if isinstance(element_or_tree, _Element):
+ return _tostring(<_Element>element_or_tree, unicode, doctype, method,
+ 0, 0, pretty_print, with_tail, -1)
+ elif isinstance(element_or_tree, _ElementTree):
+ return _tostring((<_ElementTree>element_or_tree)._context_node,
+ unicode, doctype, method, 0, 1, pretty_print,
+ with_tail, -1)
+ else:
+ raise TypeError, f"Type '{type(element_or_tree)}' cannot be serialized."
+
+
+def parse(source, _BaseParser parser=None, *, base_url=None):
+ """parse(source, parser=None, base_url=None)
+
+ Return an ElementTree object loaded with source elements. If no parser
+ is provided as second argument, the default parser is used.
+
+ The ``source`` can be any of the following:
+
+ - a file name/path
+ - a file object
+ - a file-like object
+ - a URL using the HTTP or FTP protocol
+
+ To parse from a string, use the ``fromstring()`` function instead.
+
+ Note that it is generally faster to parse from a file path or URL
+ than from an open file object or file-like object. Transparent
+ decompression from gzip compressed sources is supported (unless
+ explicitly disabled in libxml2).
+
+ The ``base_url`` keyword allows setting a URL for the document
+ when parsing from a file-like object. This is needed when looking
+ up external entities (DTD, XInclude, ...) with relative paths.
+ """
+ cdef _Document doc
+ try:
+ doc = _parseDocument(source, parser, base_url)
+ return _elementTreeFactory(doc, None)
+ except _TargetParserResult as result_container:
+ return result_container.result
+
+
+def adopt_external_document(capsule, _BaseParser parser=None):
+ """adopt_external_document(capsule, parser=None)
+
+ Unpack a libxml2 document pointer from a PyCapsule and wrap it in an
+ lxml ElementTree object.
+
+ This allows external libraries to build XML/HTML trees using libxml2
+ and then pass them efficiently into lxml for further processing.
+
+ If a ``parser`` is provided, it will be used for configuring the
+ lxml document. No parsing will be done.
+
+ The capsule must have the name ``"libxml2:xmlDoc"`` and its pointer
+ value must reference a correct libxml2 document of type ``xmlDoc*``.
+ The creator of the capsule must take care to correctly clean up the
+ document using an appropriate capsule destructor. By default, the
+ libxml2 document will be copied to let lxml safely own the memory
+ of the internal tree that it uses.
+
+ If the capsule context is non-NULL, it must point to a C string that
+ can be compared using ``strcmp()``. If the context string equals
+ ``"destructor:xmlFreeDoc"``, the libxml2 document will not be copied
+ but the capsule invalidated instead by clearing its destructor and
+ name. That way, lxml takes ownership of the libxml2 document in memory
+ without creating a copy first, and the capsule destructor will not be
+ called. The document will then eventually be cleaned up by lxml using
+ the libxml2 API function ``xmlFreeDoc()`` once it is no longer used.
+
+ If no copy is made, later modifications of the tree outside of lxml
+ should not be attempted after transferring the ownership.
+ """
+ cdef xmlDoc* c_doc
+ cdef bint is_owned = False
+ c_doc = python.lxml_unpack_xmldoc_capsule(capsule, &is_owned)
+ doc = _adoptForeignDoc(c_doc, parser, is_owned)
+ return _elementTreeFactory(doc, None)
+
+
+################################################################################
+# Include submodules
+
+include "readonlytree.pxi" # Read-only implementation of Element proxies
+include "classlookup.pxi" # Element class lookup mechanisms
+include "nsclasses.pxi" # Namespace implementation and registry
+include "docloader.pxi" # Support for custom document loaders
+include "parser.pxi" # XML and HTML parsers
+include "saxparser.pxi" # SAX-like Parser interface and tree builder
+include "parsertarget.pxi" # ET Parser target
+include "serializer.pxi" # XML output functions
+include "iterparse.pxi" # incremental XML parsing
+include "xmlid.pxi" # XMLID and IDDict
+include "xinclude.pxi" # XInclude
+include "cleanup.pxi" # Cleanup and recursive element removal functions
+
+
+################################################################################
+# Include submodules for XPath and XSLT
+
+include "extensions.pxi" # XPath/XSLT extension functions
+include "xpath.pxi" # XPath evaluation
+include "xslt.pxi" # XSL transformations
+include "xsltext.pxi" # XSL extension elements
+
+
+################################################################################
+# Validation
+
+cdef class DocumentInvalid(LxmlError):
+ """Validation error.
+
+ Raised by all document validators when their ``assertValid(tree)``
+ method fails.
+ """
+
+
+cdef class _Validator:
+ "Base class for XML validators."
+ cdef _ErrorLog _error_log
+ def __cinit__(self):
+ self._error_log = _ErrorLog()
+
+ def validate(self, etree):
+ """validate(self, etree)
+
+ Validate the document using this schema.
+
+ Returns true if document is valid, false if not.
+ """
+ return self(etree)
+
+ def assertValid(self, etree):
+ """assertValid(self, etree)
+
+ Raises `DocumentInvalid` if the document does not comply with the schema.
+ """
+ if not self(etree):
+ raise DocumentInvalid(self._error_log._buildExceptionMessage(
+ "Document does not comply with schema"),
+ self._error_log)
+
+ def assert_(self, etree):
+ """assert_(self, etree)
+
+ Raises `AssertionError` if the document does not comply with the schema.
+ """
+ if not self(etree):
+ raise AssertionError, self._error_log._buildExceptionMessage(
+ "Document does not comply with schema")
+
+ cpdef _append_log_message(self, int domain, int type, int level, int line,
+ message, filename):
+ self._error_log._receiveGeneric(domain, type, level, line, message,
+ filename)
+
+ cpdef _clear_error_log(self):
+ self._error_log.clear()
+
+ @property
+ def error_log(self):
+ """The log of validation errors and warnings."""
+ assert self._error_log is not None, "XPath evaluator not initialised"
+ return self._error_log.copy()
+
+include "dtd.pxi" # DTD
+include "relaxng.pxi" # RelaxNG
+include "xmlschema.pxi" # XMLSchema
+include "schematron.pxi" # Schematron (requires libxml2 2.6.21+)
+
+################################################################################
+# Public C API
+
+include "public-api.pxi"
+
+################################################################################
+# Other stuff
+
+include "debug.pxi"
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/etree_api.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/etree_api.h
new file mode 100644
index 0000000000000000000000000000000000000000..bbbb86b5ead5938371cd1d6dd966889a18a57dec
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/etree_api.h
@@ -0,0 +1,204 @@
+/* Generated by Cython 3.1.4 */
+
+#ifndef __PYX_HAVE_API__lxml__etree
+#define __PYX_HAVE_API__lxml__etree
+#ifdef __MINGW64__
+#define MS_WIN64
+#endif
+#include "Python.h"
+#include "etree.h"
+
+static struct LxmlElement *(*__pyx_api_f_4lxml_5etree_deepcopyNodeToDocument)(struct LxmlDocument *, xmlNode *) = 0;
+#define deepcopyNodeToDocument __pyx_api_f_4lxml_5etree_deepcopyNodeToDocument
+static struct LxmlElementTree *(*__pyx_api_f_4lxml_5etree_elementTreeFactory)(struct LxmlElement *) = 0;
+#define elementTreeFactory __pyx_api_f_4lxml_5etree_elementTreeFactory
+static struct LxmlElementTree *(*__pyx_api_f_4lxml_5etree_newElementTree)(struct LxmlElement *, PyObject *) = 0;
+#define newElementTree __pyx_api_f_4lxml_5etree_newElementTree
+static struct LxmlElementTree *(*__pyx_api_f_4lxml_5etree_adoptExternalDocument)(xmlDoc *, PyObject *, int) = 0;
+#define adoptExternalDocument __pyx_api_f_4lxml_5etree_adoptExternalDocument
+static struct LxmlElement *(*__pyx_api_f_4lxml_5etree_elementFactory)(struct LxmlDocument *, xmlNode *) = 0;
+#define elementFactory __pyx_api_f_4lxml_5etree_elementFactory
+static struct LxmlElement *(*__pyx_api_f_4lxml_5etree_makeElement)(PyObject *, struct LxmlDocument *, PyObject *, PyObject *, PyObject *, PyObject *, PyObject *) = 0;
+#define makeElement __pyx_api_f_4lxml_5etree_makeElement
+static struct LxmlElement *(*__pyx_api_f_4lxml_5etree_makeSubElement)(struct LxmlElement *, PyObject *, PyObject *, PyObject *, PyObject *, PyObject *) = 0;
+#define makeSubElement __pyx_api_f_4lxml_5etree_makeSubElement
+static void (*__pyx_api_f_4lxml_5etree_setElementClassLookupFunction)(_element_class_lookup_function, PyObject *) = 0;
+#define setElementClassLookupFunction __pyx_api_f_4lxml_5etree_setElementClassLookupFunction
+static PyObject *(*__pyx_api_f_4lxml_5etree_lookupDefaultElementClass)(PyObject *, PyObject *, xmlNode *) = 0;
+#define lookupDefaultElementClass __pyx_api_f_4lxml_5etree_lookupDefaultElementClass
+static PyObject *(*__pyx_api_f_4lxml_5etree_lookupNamespaceElementClass)(PyObject *, PyObject *, xmlNode *) = 0;
+#define lookupNamespaceElementClass __pyx_api_f_4lxml_5etree_lookupNamespaceElementClass
+static PyObject *(*__pyx_api_f_4lxml_5etree_callLookupFallback)(struct LxmlFallbackElementClassLookup *, struct LxmlDocument *, xmlNode *) = 0;
+#define callLookupFallback __pyx_api_f_4lxml_5etree_callLookupFallback
+static int (*__pyx_api_f_4lxml_5etree_tagMatches)(xmlNode *, const xmlChar *, const xmlChar *) = 0;
+#define tagMatches __pyx_api_f_4lxml_5etree_tagMatches
+static struct LxmlDocument *(*__pyx_api_f_4lxml_5etree_documentOrRaise)(PyObject *) = 0;
+#define documentOrRaise __pyx_api_f_4lxml_5etree_documentOrRaise
+static struct LxmlElement *(*__pyx_api_f_4lxml_5etree_rootNodeOrRaise)(PyObject *) = 0;
+#define rootNodeOrRaise __pyx_api_f_4lxml_5etree_rootNodeOrRaise
+static int (*__pyx_api_f_4lxml_5etree_hasText)(xmlNode *) = 0;
+#define hasText __pyx_api_f_4lxml_5etree_hasText
+static int (*__pyx_api_f_4lxml_5etree_hasTail)(xmlNode *) = 0;
+#define hasTail __pyx_api_f_4lxml_5etree_hasTail
+static PyObject *(*__pyx_api_f_4lxml_5etree_textOf)(xmlNode *) = 0;
+#define textOf __pyx_api_f_4lxml_5etree_textOf
+static PyObject *(*__pyx_api_f_4lxml_5etree_tailOf)(xmlNode *) = 0;
+#define tailOf __pyx_api_f_4lxml_5etree_tailOf
+static int (*__pyx_api_f_4lxml_5etree_setNodeText)(xmlNode *, PyObject *) = 0;
+#define setNodeText __pyx_api_f_4lxml_5etree_setNodeText
+static int (*__pyx_api_f_4lxml_5etree_setTailText)(xmlNode *, PyObject *) = 0;
+#define setTailText __pyx_api_f_4lxml_5etree_setTailText
+static PyObject *(*__pyx_api_f_4lxml_5etree_attributeValue)(xmlNode *, xmlAttr *) = 0;
+#define attributeValue __pyx_api_f_4lxml_5etree_attributeValue
+static PyObject *(*__pyx_api_f_4lxml_5etree_attributeValueFromNsName)(xmlNode *, const xmlChar *, const xmlChar *) = 0;
+#define attributeValueFromNsName __pyx_api_f_4lxml_5etree_attributeValueFromNsName
+static PyObject *(*__pyx_api_f_4lxml_5etree_getAttributeValue)(struct LxmlElement *, PyObject *, PyObject *) = 0;
+#define getAttributeValue __pyx_api_f_4lxml_5etree_getAttributeValue
+static PyObject *(*__pyx_api_f_4lxml_5etree_iterattributes)(struct LxmlElement *, int) = 0;
+#define iterattributes __pyx_api_f_4lxml_5etree_iterattributes
+static PyObject *(*__pyx_api_f_4lxml_5etree_collectAttributes)(xmlNode *, int) = 0;
+#define collectAttributes __pyx_api_f_4lxml_5etree_collectAttributes
+static int (*__pyx_api_f_4lxml_5etree_setAttributeValue)(struct LxmlElement *, PyObject *, PyObject *) = 0;
+#define setAttributeValue __pyx_api_f_4lxml_5etree_setAttributeValue
+static int (*__pyx_api_f_4lxml_5etree_delAttribute)(struct LxmlElement *, PyObject *) = 0;
+#define delAttribute __pyx_api_f_4lxml_5etree_delAttribute
+static int (*__pyx_api_f_4lxml_5etree_delAttributeFromNsName)(xmlNode *, const xmlChar *, const xmlChar *) = 0;
+#define delAttributeFromNsName __pyx_api_f_4lxml_5etree_delAttributeFromNsName
+static int (*__pyx_api_f_4lxml_5etree_hasChild)(xmlNode *) = 0;
+#define hasChild __pyx_api_f_4lxml_5etree_hasChild
+static xmlNode *(*__pyx_api_f_4lxml_5etree_findChild)(xmlNode *, Py_ssize_t) = 0;
+#define findChild __pyx_api_f_4lxml_5etree_findChild
+static xmlNode *(*__pyx_api_f_4lxml_5etree_findChildForwards)(xmlNode *, Py_ssize_t) = 0;
+#define findChildForwards __pyx_api_f_4lxml_5etree_findChildForwards
+static xmlNode *(*__pyx_api_f_4lxml_5etree_findChildBackwards)(xmlNode *, Py_ssize_t) = 0;
+#define findChildBackwards __pyx_api_f_4lxml_5etree_findChildBackwards
+static xmlNode *(*__pyx_api_f_4lxml_5etree_nextElement)(xmlNode *) = 0;
+#define nextElement __pyx_api_f_4lxml_5etree_nextElement
+static xmlNode *(*__pyx_api_f_4lxml_5etree_previousElement)(xmlNode *) = 0;
+#define previousElement __pyx_api_f_4lxml_5etree_previousElement
+static void (*__pyx_api_f_4lxml_5etree_appendChild)(struct LxmlElement *, struct LxmlElement *) = 0;
+#define appendChild __pyx_api_f_4lxml_5etree_appendChild
+static int (*__pyx_api_f_4lxml_5etree_appendChildToElement)(struct LxmlElement *, struct LxmlElement *) = 0;
+#define appendChildToElement __pyx_api_f_4lxml_5etree_appendChildToElement
+static PyObject *(*__pyx_api_f_4lxml_5etree_pyunicode)(const xmlChar *) = 0;
+#define pyunicode __pyx_api_f_4lxml_5etree_pyunicode
+static PyObject *(*__pyx_api_f_4lxml_5etree_utf8)(PyObject *) = 0;
+#define utf8 __pyx_api_f_4lxml_5etree_utf8
+static PyObject *(*__pyx_api_f_4lxml_5etree_getNsTag)(PyObject *) = 0;
+#define getNsTag __pyx_api_f_4lxml_5etree_getNsTag
+static PyObject *(*__pyx_api_f_4lxml_5etree_getNsTagWithEmptyNs)(PyObject *) = 0;
+#define getNsTagWithEmptyNs __pyx_api_f_4lxml_5etree_getNsTagWithEmptyNs
+static PyObject *(*__pyx_api_f_4lxml_5etree_namespacedName)(xmlNode *) = 0;
+#define namespacedName __pyx_api_f_4lxml_5etree_namespacedName
+static PyObject *(*__pyx_api_f_4lxml_5etree_namespacedNameFromNsName)(const xmlChar *, const xmlChar *) = 0;
+#define namespacedNameFromNsName __pyx_api_f_4lxml_5etree_namespacedNameFromNsName
+static void (*__pyx_api_f_4lxml_5etree_iteratorStoreNext)(struct LxmlElementIterator *, struct LxmlElement *) = 0;
+#define iteratorStoreNext __pyx_api_f_4lxml_5etree_iteratorStoreNext
+static void (*__pyx_api_f_4lxml_5etree_initTagMatch)(struct LxmlElementTagMatcher *, PyObject *) = 0;
+#define initTagMatch __pyx_api_f_4lxml_5etree_initTagMatch
+static xmlNs *(*__pyx_api_f_4lxml_5etree_findOrBuildNodeNsPrefix)(struct LxmlDocument *, xmlNode *, const xmlChar *, const xmlChar *) = 0;
+#define findOrBuildNodeNsPrefix __pyx_api_f_4lxml_5etree_findOrBuildNodeNsPrefix
+static int __Pyx_ImportFunction_3_1_4(PyObject *module, const char *funcname, void (**f)(void), const char *sig);
+
+#ifndef __PYX_HAVE_RT_ImportFunction_3_1_4
+#define __PYX_HAVE_RT_ImportFunction_3_1_4
+static int __Pyx_ImportFunction_3_1_4(PyObject *module, const char *funcname, void (**f)(void), const char *sig) {
+ PyObject *d = 0;
+ PyObject *cobj = 0;
+ union {
+ void (*fp)(void);
+ void *p;
+ } tmp;
+ d = PyObject_GetAttrString(module, "__pyx_capi__");
+ if (!d)
+ goto bad;
+#if (defined(Py_LIMITED_API) && Py_LIMITED_API >= 0x030d0000) || (!defined(Py_LIMITED_API) && PY_VERSION_HEX >= 0x030d0000)
+ PyDict_GetItemStringRef(d, funcname, &cobj);
+#else
+ cobj = PyDict_GetItemString(d, funcname);
+ Py_XINCREF(cobj);
+#endif
+ if (!cobj) {
+ PyErr_Format(PyExc_ImportError,
+ "%.200s does not export expected C function %.200s",
+ PyModule_GetName(module), funcname);
+ goto bad;
+ }
+ if (!PyCapsule_IsValid(cobj, sig)) {
+ PyErr_Format(PyExc_TypeError,
+ "C function %.200s.%.200s has wrong signature (expected %.500s, got %.500s)",
+ PyModule_GetName(module), funcname, sig, PyCapsule_GetName(cobj));
+ goto bad;
+ }
+ tmp.p = PyCapsule_GetPointer(cobj, sig);
+ *f = tmp.fp;
+ if (!(*f))
+ goto bad;
+ Py_DECREF(d);
+ Py_DECREF(cobj);
+ return 0;
+bad:
+ Py_XDECREF(d);
+ Py_XDECREF(cobj);
+ return -1;
+}
+#endif
+
+
+static int import_lxml__etree(void) {
+ PyObject *module = 0;
+ module = PyImport_ImportModule("lxml.etree");
+ if (!module) goto bad;
+ if (__Pyx_ImportFunction_3_1_4(module, "deepcopyNodeToDocument", (void (**)(void))&__pyx_api_f_4lxml_5etree_deepcopyNodeToDocument, "struct LxmlElement *(struct LxmlDocument *, xmlNode *)") < 0) goto bad;
+ if (__Pyx_ImportFunction_3_1_4(module, "elementTreeFactory", (void (**)(void))&__pyx_api_f_4lxml_5etree_elementTreeFactory, "struct LxmlElementTree *(struct LxmlElement *)") < 0) goto bad;
+ if (__Pyx_ImportFunction_3_1_4(module, "newElementTree", (void (**)(void))&__pyx_api_f_4lxml_5etree_newElementTree, "struct LxmlElementTree *(struct LxmlElement *, PyObject *)") < 0) goto bad;
+ if (__Pyx_ImportFunction_3_1_4(module, "adoptExternalDocument", (void (**)(void))&__pyx_api_f_4lxml_5etree_adoptExternalDocument, "struct LxmlElementTree *(xmlDoc *, PyObject *, int)") < 0) goto bad;
+ if (__Pyx_ImportFunction_3_1_4(module, "elementFactory", (void (**)(void))&__pyx_api_f_4lxml_5etree_elementFactory, "struct LxmlElement *(struct LxmlDocument *, xmlNode *)") < 0) goto bad;
+ if (__Pyx_ImportFunction_3_1_4(module, "makeElement", (void (**)(void))&__pyx_api_f_4lxml_5etree_makeElement, "struct LxmlElement *(PyObject *, struct LxmlDocument *, PyObject *, PyObject *, PyObject *, PyObject *, PyObject *)") < 0) goto bad;
+ if (__Pyx_ImportFunction_3_1_4(module, "makeSubElement", (void (**)(void))&__pyx_api_f_4lxml_5etree_makeSubElement, "struct LxmlElement *(struct LxmlElement *, PyObject *, PyObject *, PyObject *, PyObject *, PyObject *)") < 0) goto bad;
+ if (__Pyx_ImportFunction_3_1_4(module, "setElementClassLookupFunction", (void (**)(void))&__pyx_api_f_4lxml_5etree_setElementClassLookupFunction, "void (_element_class_lookup_function, PyObject *)") < 0) goto bad;
+ if (__Pyx_ImportFunction_3_1_4(module, "lookupDefaultElementClass", (void (**)(void))&__pyx_api_f_4lxml_5etree_lookupDefaultElementClass, "PyObject *(PyObject *, PyObject *, xmlNode *)") < 0) goto bad;
+ if (__Pyx_ImportFunction_3_1_4(module, "lookupNamespaceElementClass", (void (**)(void))&__pyx_api_f_4lxml_5etree_lookupNamespaceElementClass, "PyObject *(PyObject *, PyObject *, xmlNode *)") < 0) goto bad;
+ if (__Pyx_ImportFunction_3_1_4(module, "callLookupFallback", (void (**)(void))&__pyx_api_f_4lxml_5etree_callLookupFallback, "PyObject *(struct LxmlFallbackElementClassLookup *, struct LxmlDocument *, xmlNode *)") < 0) goto bad;
+ if (__Pyx_ImportFunction_3_1_4(module, "tagMatches", (void (**)(void))&__pyx_api_f_4lxml_5etree_tagMatches, "int (xmlNode *, const xmlChar *, const xmlChar *)") < 0) goto bad;
+ if (__Pyx_ImportFunction_3_1_4(module, "documentOrRaise", (void (**)(void))&__pyx_api_f_4lxml_5etree_documentOrRaise, "struct LxmlDocument *(PyObject *)") < 0) goto bad;
+ if (__Pyx_ImportFunction_3_1_4(module, "rootNodeOrRaise", (void (**)(void))&__pyx_api_f_4lxml_5etree_rootNodeOrRaise, "struct LxmlElement *(PyObject *)") < 0) goto bad;
+ if (__Pyx_ImportFunction_3_1_4(module, "hasText", (void (**)(void))&__pyx_api_f_4lxml_5etree_hasText, "int (xmlNode *)") < 0) goto bad;
+ if (__Pyx_ImportFunction_3_1_4(module, "hasTail", (void (**)(void))&__pyx_api_f_4lxml_5etree_hasTail, "int (xmlNode *)") < 0) goto bad;
+ if (__Pyx_ImportFunction_3_1_4(module, "textOf", (void (**)(void))&__pyx_api_f_4lxml_5etree_textOf, "PyObject *(xmlNode *)") < 0) goto bad;
+ if (__Pyx_ImportFunction_3_1_4(module, "tailOf", (void (**)(void))&__pyx_api_f_4lxml_5etree_tailOf, "PyObject *(xmlNode *)") < 0) goto bad;
+ if (__Pyx_ImportFunction_3_1_4(module, "setNodeText", (void (**)(void))&__pyx_api_f_4lxml_5etree_setNodeText, "int (xmlNode *, PyObject *)") < 0) goto bad;
+ if (__Pyx_ImportFunction_3_1_4(module, "setTailText", (void (**)(void))&__pyx_api_f_4lxml_5etree_setTailText, "int (xmlNode *, PyObject *)") < 0) goto bad;
+ if (__Pyx_ImportFunction_3_1_4(module, "attributeValue", (void (**)(void))&__pyx_api_f_4lxml_5etree_attributeValue, "PyObject *(xmlNode *, xmlAttr *)") < 0) goto bad;
+ if (__Pyx_ImportFunction_3_1_4(module, "attributeValueFromNsName", (void (**)(void))&__pyx_api_f_4lxml_5etree_attributeValueFromNsName, "PyObject *(xmlNode *, const xmlChar *, const xmlChar *)") < 0) goto bad;
+ if (__Pyx_ImportFunction_3_1_4(module, "getAttributeValue", (void (**)(void))&__pyx_api_f_4lxml_5etree_getAttributeValue, "PyObject *(struct LxmlElement *, PyObject *, PyObject *)") < 0) goto bad;
+ if (__Pyx_ImportFunction_3_1_4(module, "iterattributes", (void (**)(void))&__pyx_api_f_4lxml_5etree_iterattributes, "PyObject *(struct LxmlElement *, int)") < 0) goto bad;
+ if (__Pyx_ImportFunction_3_1_4(module, "collectAttributes", (void (**)(void))&__pyx_api_f_4lxml_5etree_collectAttributes, "PyObject *(xmlNode *, int)") < 0) goto bad;
+ if (__Pyx_ImportFunction_3_1_4(module, "setAttributeValue", (void (**)(void))&__pyx_api_f_4lxml_5etree_setAttributeValue, "int (struct LxmlElement *, PyObject *, PyObject *)") < 0) goto bad;
+ if (__Pyx_ImportFunction_3_1_4(module, "delAttribute", (void (**)(void))&__pyx_api_f_4lxml_5etree_delAttribute, "int (struct LxmlElement *, PyObject *)") < 0) goto bad;
+ if (__Pyx_ImportFunction_3_1_4(module, "delAttributeFromNsName", (void (**)(void))&__pyx_api_f_4lxml_5etree_delAttributeFromNsName, "int (xmlNode *, const xmlChar *, const xmlChar *)") < 0) goto bad;
+ if (__Pyx_ImportFunction_3_1_4(module, "hasChild", (void (**)(void))&__pyx_api_f_4lxml_5etree_hasChild, "int (xmlNode *)") < 0) goto bad;
+ if (__Pyx_ImportFunction_3_1_4(module, "findChild", (void (**)(void))&__pyx_api_f_4lxml_5etree_findChild, "xmlNode *(xmlNode *, Py_ssize_t)") < 0) goto bad;
+ if (__Pyx_ImportFunction_3_1_4(module, "findChildForwards", (void (**)(void))&__pyx_api_f_4lxml_5etree_findChildForwards, "xmlNode *(xmlNode *, Py_ssize_t)") < 0) goto bad;
+ if (__Pyx_ImportFunction_3_1_4(module, "findChildBackwards", (void (**)(void))&__pyx_api_f_4lxml_5etree_findChildBackwards, "xmlNode *(xmlNode *, Py_ssize_t)") < 0) goto bad;
+ if (__Pyx_ImportFunction_3_1_4(module, "nextElement", (void (**)(void))&__pyx_api_f_4lxml_5etree_nextElement, "xmlNode *(xmlNode *)") < 0) goto bad;
+ if (__Pyx_ImportFunction_3_1_4(module, "previousElement", (void (**)(void))&__pyx_api_f_4lxml_5etree_previousElement, "xmlNode *(xmlNode *)") < 0) goto bad;
+ if (__Pyx_ImportFunction_3_1_4(module, "appendChild", (void (**)(void))&__pyx_api_f_4lxml_5etree_appendChild, "void (struct LxmlElement *, struct LxmlElement *)") < 0) goto bad;
+ if (__Pyx_ImportFunction_3_1_4(module, "appendChildToElement", (void (**)(void))&__pyx_api_f_4lxml_5etree_appendChildToElement, "int (struct LxmlElement *, struct LxmlElement *)") < 0) goto bad;
+ if (__Pyx_ImportFunction_3_1_4(module, "pyunicode", (void (**)(void))&__pyx_api_f_4lxml_5etree_pyunicode, "PyObject *(const xmlChar *)") < 0) goto bad;
+ if (__Pyx_ImportFunction_3_1_4(module, "utf8", (void (**)(void))&__pyx_api_f_4lxml_5etree_utf8, "PyObject *(PyObject *)") < 0) goto bad;
+ if (__Pyx_ImportFunction_3_1_4(module, "getNsTag", (void (**)(void))&__pyx_api_f_4lxml_5etree_getNsTag, "PyObject *(PyObject *)") < 0) goto bad;
+ if (__Pyx_ImportFunction_3_1_4(module, "getNsTagWithEmptyNs", (void (**)(void))&__pyx_api_f_4lxml_5etree_getNsTagWithEmptyNs, "PyObject *(PyObject *)") < 0) goto bad;
+ if (__Pyx_ImportFunction_3_1_4(module, "namespacedName", (void (**)(void))&__pyx_api_f_4lxml_5etree_namespacedName, "PyObject *(xmlNode *)") < 0) goto bad;
+ if (__Pyx_ImportFunction_3_1_4(module, "namespacedNameFromNsName", (void (**)(void))&__pyx_api_f_4lxml_5etree_namespacedNameFromNsName, "PyObject *(const xmlChar *, const xmlChar *)") < 0) goto bad;
+ if (__Pyx_ImportFunction_3_1_4(module, "iteratorStoreNext", (void (**)(void))&__pyx_api_f_4lxml_5etree_iteratorStoreNext, "void (struct LxmlElementIterator *, struct LxmlElement *)") < 0) goto bad;
+ if (__Pyx_ImportFunction_3_1_4(module, "initTagMatch", (void (**)(void))&__pyx_api_f_4lxml_5etree_initTagMatch, "void (struct LxmlElementTagMatcher *, PyObject *)") < 0) goto bad;
+ if (__Pyx_ImportFunction_3_1_4(module, "findOrBuildNodeNsPrefix", (void (**)(void))&__pyx_api_f_4lxml_5etree_findOrBuildNodeNsPrefix, "xmlNs *(struct LxmlDocument *, xmlNode *, const xmlChar *, const xmlChar *)") < 0) goto bad;
+ Py_DECREF(module); module = 0;
+ return 0;
+ bad:
+ Py_XDECREF(module);
+ return -1;
+}
+
+#endif /* !__PYX_HAVE_API__lxml__etree */
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/extensions.pxi b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/extensions.pxi
new file mode 100644
index 0000000000000000000000000000000000000000..ab687bec9c1d58c1220fae31bce1712d4751a9f2
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/extensions.pxi
@@ -0,0 +1,830 @@
+# support for extension functions in XPath and XSLT
+
+cdef class XPathError(LxmlError):
+ """Base class of all XPath errors.
+ """
+
+cdef class XPathEvalError(XPathError):
+ """Error during XPath evaluation.
+ """
+
+cdef class XPathFunctionError(XPathEvalError):
+ """Internal error looking up an XPath extension function.
+ """
+
+cdef class XPathResultError(XPathEvalError):
+ """Error handling an XPath result.
+ """
+
+
+# forward declarations
+
+ctypedef int (*_register_function)(void* ctxt, name_utf, ns_uri_utf)
+cdef class _ExsltRegExp
+
+################################################################################
+# Base class for XSLT and XPath evaluation contexts: functions, namespaces, ...
+
+@cython.internal
+cdef class _BaseContext:
+ cdef xpath.xmlXPathContext* _xpathCtxt
+ cdef _Document _doc
+ cdef dict _extensions
+ cdef list _namespaces
+ cdef list _global_namespaces
+ cdef dict _utf_refs
+ cdef dict _function_cache
+ cdef dict _eval_context_dict
+ cdef bint _build_smart_strings
+ # for exception handling and temporary reference keeping:
+ cdef _TempStore _temp_refs
+ cdef set _temp_documents
+ cdef _ExceptionContext _exc
+ cdef _ErrorLog _error_log
+
+ def __init__(self, namespaces, extensions, error_log, enable_regexp,
+ build_smart_strings):
+ cdef _ExsltRegExp _regexp
+ cdef dict new_extensions
+ cdef list ns
+ self._utf_refs = {}
+ self._global_namespaces = []
+ self._function_cache = {}
+ self._eval_context_dict = None
+ self._error_log = error_log
+
+ if extensions is not None:
+ # convert extensions to UTF-8
+ if isinstance(extensions, dict):
+ extensions = (extensions,)
+ # format: [ {(ns, name):function} ] -> {(ns_utf, name_utf):function}
+ new_extensions = {}
+ for extension in extensions:
+ for (ns_uri, name), function in extension.items():
+ if name is None:
+ raise ValueError, "extensions must have non empty names"
+ ns_utf = self._to_utf(ns_uri)
+ name_utf = self._to_utf(name)
+ new_extensions[(ns_utf, name_utf)] = function
+ extensions = new_extensions or None
+
+ if namespaces is not None:
+ if isinstance(namespaces, dict):
+ namespaces = namespaces.items()
+ if namespaces:
+ ns = []
+ for prefix, ns_uri in namespaces:
+ if prefix is None or not prefix:
+ raise TypeError, \
+ "empty namespace prefix is not supported in XPath"
+ if ns_uri is None or not ns_uri:
+ raise TypeError, \
+ "setting default namespace is not supported in XPath"
+ prefix_utf = self._to_utf(prefix)
+ ns_uri_utf = self._to_utf(ns_uri)
+ ns.append( (prefix_utf, ns_uri_utf) )
+ namespaces = ns
+ else:
+ namespaces = None
+
+ self._doc = None
+ self._exc = _ExceptionContext()
+ self._extensions = extensions
+ self._namespaces = namespaces
+ self._temp_refs = _TempStore()
+ self._temp_documents = set()
+ self._build_smart_strings = build_smart_strings
+
+ if enable_regexp:
+ _regexp = _ExsltRegExp()
+ _regexp._register_in_context(self)
+
+ cdef _BaseContext _copy(self):
+ cdef _BaseContext context
+ if self._namespaces is not None:
+ namespaces = self._namespaces[:]
+ else:
+ namespaces = None
+ context = self.__class__(namespaces, None, self._error_log, False,
+ self._build_smart_strings)
+ if self._extensions is not None:
+ context._extensions = self._extensions.copy()
+ return context
+
+ cdef bytes _to_utf(self, s):
+ "Convert to UTF-8 and keep a reference to the encoded string"
+ cdef python.PyObject* dict_result
+ if s is None:
+ return None
+ dict_result = python.PyDict_GetItem(self._utf_refs, s)
+ if dict_result is not NULL:
+ return dict_result
+ utf = _utf8(s)
+ self._utf_refs[s] = utf
+ if python.IS_PYPY:
+ # use C level refs, PyPy refs are not enough!
+ python.Py_INCREF(utf)
+ return utf
+
+ cdef void _set_xpath_context(self, xpath.xmlXPathContext* xpathCtxt) noexcept:
+ self._xpathCtxt = xpathCtxt
+ xpathCtxt.userData = self
+ # Need a cast here because older libxml2 releases do not use 'const' in the functype.
+ xpathCtxt.error = _receiveXPathError
+
+ @cython.final
+ cdef _register_context(self, _Document doc):
+ self._doc = doc
+ self._exc.clear()
+
+ @cython.final
+ cdef _cleanup_context(self):
+ #xpath.xmlXPathRegisteredNsCleanup(self._xpathCtxt)
+ #self.unregisterGlobalNamespaces()
+ if python.IS_PYPY:
+ # clean up double refs in PyPy (see "_to_utf()" method)
+ for ref in self._utf_refs.itervalues():
+ python.Py_DECREF(ref)
+ self._utf_refs.clear()
+ self._eval_context_dict = None
+ self._doc = None
+
+ @cython.final
+ cdef _release_context(self):
+ if self._xpathCtxt is not NULL:
+ self._xpathCtxt.userData = NULL
+ self._xpathCtxt = NULL
+
+ # namespaces (internal UTF-8 methods with leading '_')
+
+ cdef addNamespace(self, prefix, ns_uri):
+ cdef list namespaces
+ if prefix is None:
+ raise TypeError, "empty prefix is not supported in XPath"
+ prefix_utf = self._to_utf(prefix)
+ ns_uri_utf = self._to_utf(ns_uri)
+ new_item = (prefix_utf, ns_uri_utf)
+ if self._namespaces is None:
+ self._namespaces = [new_item]
+ else:
+ namespaces = []
+ for item in self._namespaces:
+ if item[0] == prefix_utf:
+ item = new_item
+ new_item = None
+ namespaces.append(item)
+ if new_item is not None:
+ namespaces.append(new_item)
+ self._namespaces = namespaces
+ if self._xpathCtxt is not NULL:
+ xpath.xmlXPathRegisterNs(
+ self._xpathCtxt, _xcstr(prefix_utf), _xcstr(ns_uri_utf))
+
+ cdef registerNamespace(self, prefix, ns_uri):
+ if prefix is None:
+ raise TypeError, "empty prefix is not supported in XPath"
+ prefix_utf = self._to_utf(prefix)
+ ns_uri_utf = self._to_utf(ns_uri)
+ self._global_namespaces.append(prefix_utf)
+ xpath.xmlXPathRegisterNs(self._xpathCtxt,
+ _xcstr(prefix_utf), _xcstr(ns_uri_utf))
+
+ cdef registerLocalNamespaces(self):
+ if self._namespaces is None:
+ return
+ for prefix_utf, ns_uri_utf in self._namespaces:
+ xpath.xmlXPathRegisterNs(
+ self._xpathCtxt, _xcstr(prefix_utf), _xcstr(ns_uri_utf))
+
+ cdef registerGlobalNamespaces(self):
+ cdef list ns_prefixes = _find_all_extension_prefixes()
+ if python.PyList_GET_SIZE(ns_prefixes) > 0:
+ for prefix_utf, ns_uri_utf in ns_prefixes:
+ self._global_namespaces.append(prefix_utf)
+ xpath.xmlXPathRegisterNs(
+ self._xpathCtxt, _xcstr(prefix_utf), _xcstr(ns_uri_utf))
+
+ cdef unregisterGlobalNamespaces(self):
+ if python.PyList_GET_SIZE(self._global_namespaces) > 0:
+ for prefix_utf in self._global_namespaces:
+ xpath.xmlXPathRegisterNs(self._xpathCtxt,
+ _xcstr(prefix_utf), NULL)
+ del self._global_namespaces[:]
+
+ cdef void _unregisterNamespace(self, prefix_utf) noexcept:
+ xpath.xmlXPathRegisterNs(self._xpathCtxt,
+ _xcstr(prefix_utf), NULL)
+
+ # extension functions
+
+ cdef int _addLocalExtensionFunction(self, ns_utf, name_utf, function) except -1:
+ if self._extensions is None:
+ self._extensions = {}
+ self._extensions[(ns_utf, name_utf)] = function
+ return 0
+
+ cdef registerGlobalFunctions(self, void* ctxt,
+ _register_function reg_func):
+ cdef python.PyObject* dict_result
+ cdef dict d
+ for ns_utf, ns_functions in __FUNCTION_NAMESPACE_REGISTRIES.iteritems():
+ dict_result = python.PyDict_GetItem(
+ self._function_cache, ns_utf)
+ if dict_result is not NULL:
+ d = dict_result
+ else:
+ d = {}
+ self._function_cache[ns_utf] = d
+ for name_utf, function in ns_functions.iteritems():
+ d[name_utf] = function
+ reg_func(ctxt, name_utf, ns_utf)
+
+ cdef registerLocalFunctions(self, void* ctxt,
+ _register_function reg_func):
+ cdef python.PyObject* dict_result
+ cdef dict d
+ if self._extensions is None:
+ return # done
+ last_ns = None
+ d = None
+ for (ns_utf, name_utf), function in self._extensions.iteritems():
+ if ns_utf is not last_ns or d is None:
+ last_ns = ns_utf
+ dict_result = python.PyDict_GetItem(
+ self._function_cache, ns_utf)
+ if dict_result is not NULL:
+ d = dict_result
+ else:
+ d = {}
+ self._function_cache[ns_utf] = d
+ d[name_utf] = function
+ reg_func(ctxt, name_utf, ns_utf)
+
+ cdef unregisterAllFunctions(self, void* ctxt,
+ _register_function unreg_func):
+ for ns_utf, functions in self._function_cache.iteritems():
+ for name_utf in functions:
+ unreg_func(ctxt, name_utf, ns_utf)
+
+ cdef unregisterGlobalFunctions(self, void* ctxt,
+ _register_function unreg_func):
+ for ns_utf, functions in self._function_cache.items():
+ for name_utf in functions:
+ if self._extensions is None or \
+ (ns_utf, name_utf) not in self._extensions:
+ unreg_func(ctxt, name_utf, ns_utf)
+
+ @cython.final
+ cdef _find_cached_function(self, const_xmlChar* c_ns_uri, const_xmlChar* c_name):
+ """Lookup an extension function in the cache and return it.
+
+ Parameters: c_ns_uri may be NULL, c_name must not be NULL
+ """
+ cdef python.PyObject* c_dict
+ cdef python.PyObject* dict_result
+ c_dict = python.PyDict_GetItem(
+ self._function_cache, None if c_ns_uri is NULL else c_ns_uri)
+ if c_dict is not NULL:
+ dict_result = python.PyDict_GetItem(
+ c_dict, c_name)
+ if dict_result is not NULL:
+ return dict_result
+ return None
+
+ # Python access to the XPath context for extension functions
+
+ @property
+ def context_node(self):
+ cdef xmlNode* c_node
+ if self._xpathCtxt is NULL:
+ raise XPathError, \
+ "XPath context is only usable during the evaluation"
+ c_node = self._xpathCtxt.node
+ if c_node is NULL:
+ raise XPathError, "no context node"
+ if c_node.doc != self._xpathCtxt.doc:
+ raise XPathError, \
+ "document-external context nodes are not supported"
+ if self._doc is None:
+ raise XPathError, "document context is missing"
+ return _elementFactory(self._doc, c_node)
+
+ @property
+ def eval_context(self):
+ if self._eval_context_dict is None:
+ self._eval_context_dict = {}
+ return self._eval_context_dict
+
+ # Python reference keeping during XPath function evaluation
+
+ @cython.final
+ cdef _release_temp_refs(self):
+ "Free temporarily referenced objects from this context."
+ self._temp_refs.clear()
+ self._temp_documents.clear()
+
+ @cython.final
+ cdef _hold(self, obj):
+ """A way to temporarily hold references to nodes in the evaluator.
+
+ This is needed because otherwise nodes created in XPath extension
+ functions would be reference counted too soon, during the XPath
+ evaluation. This is most important in the case of exceptions.
+ """
+ cdef _Element element
+ if isinstance(obj, _Element):
+ self._temp_refs.add(obj)
+ self._temp_documents.add((<_Element>obj)._doc)
+ return
+ elif _isString(obj) or not python.PySequence_Check(obj):
+ return
+ for o in obj:
+ if isinstance(o, _Element):
+ #print "Holding element:", element._c_node
+ self._temp_refs.add(o)
+ #print "Holding document:", element._doc._c_doc
+ self._temp_documents.add((<_Element>o)._doc)
+
+ @cython.final
+ cdef _Document _findDocumentForNode(self, xmlNode* c_node):
+ """If an XPath expression returns an element from a different
+ document than the current context document, we call this to
+ see if it was possibly created by an extension and is a known
+ document instance.
+ """
+ cdef _Document doc
+ for doc in self._temp_documents:
+ if doc is not None and doc._c_doc is c_node.doc:
+ return doc
+ return None
+
+
+# libxml2 keeps these error messages in a static array in its code
+# and doesn't give us access to them ...
+
+cdef tuple LIBXML2_XPATH_ERROR_MESSAGES = (
+ b"Ok",
+ b"Number encoding",
+ b"Unfinished literal",
+ b"Start of literal",
+ b"Expected $ for variable reference",
+ b"Undefined variable",
+ b"Invalid predicate",
+ b"Invalid expression",
+ b"Missing closing curly brace",
+ b"Unregistered function",
+ b"Invalid operand",
+ b"Invalid type",
+ b"Invalid number of arguments",
+ b"Invalid context size",
+ b"Invalid context position",
+ b"Memory allocation error",
+ b"Syntax error",
+ b"Resource error",
+ b"Sub resource error",
+ b"Undefined namespace prefix",
+ b"Encoding error",
+ b"Char out of XML range",
+ b"Invalid or incomplete context",
+ b"Stack usage error",
+ b"Forbidden variable\n",
+ b"?? Unknown error ??\n",
+)
+
+cdef void _forwardXPathError(void* c_ctxt, const xmlerror.xmlError* c_error) noexcept with gil:
+ cdef xmlerror.xmlError error
+ cdef int xpath_code
+ if c_error.message is not NULL:
+ error.message = c_error.message
+ else:
+ xpath_code = c_error.code - xmlerror.XML_XPATH_EXPRESSION_OK
+ if 0 <= xpath_code < len(LIBXML2_XPATH_ERROR_MESSAGES):
+ error.message = _cstr(LIBXML2_XPATH_ERROR_MESSAGES[xpath_code])
+ else:
+ error.message = b"unknown error"
+ error.domain = c_error.domain
+ error.code = c_error.code
+ error.level = c_error.level
+ error.line = c_error.line
+ error.int2 = c_error.int1 # column
+ error.file = c_error.file
+ error.node = NULL
+
+ (<_BaseContext>c_ctxt)._error_log._receive(&error)
+
+cdef void _receiveXPathError(void* c_context, const xmlerror.xmlError* error) noexcept nogil:
+ if not __DEBUG:
+ return
+ if c_context is NULL:
+ _forwardError(NULL, error)
+ else:
+ _forwardXPathError(c_context, error)
+
+
+def Extension(module, function_mapping=None, *, ns=None):
+ """Extension(module, function_mapping=None, ns=None)
+
+ Build a dictionary of extension functions from the functions
+ defined in a module or the methods of an object.
+
+ As second argument, you can pass an additional mapping of
+ attribute names to XPath function names, or a list of function
+ names that should be taken.
+
+ The ``ns`` keyword argument accepts a namespace URI for the XPath
+ functions.
+ """
+ cdef dict functions = {}
+ if isinstance(function_mapping, dict):
+ for function_name, xpath_name in function_mapping.items():
+ functions[(ns, xpath_name)] = getattr(module, function_name)
+ else:
+ if function_mapping is None:
+ function_mapping = [ name for name in dir(module)
+ if not name.startswith('_') ]
+ for function_name in function_mapping:
+ functions[(ns, function_name)] = getattr(module, function_name)
+ return functions
+
+################################################################################
+# EXSLT regexp implementation
+
+@cython.final
+@cython.internal
+cdef class _ExsltRegExp:
+ cdef dict _compile_map
+ def __cinit__(self):
+ self._compile_map = {}
+
+ cdef _make_string(self, value):
+ if _isString(value):
+ return value
+ elif isinstance(value, list):
+ # node set: take recursive text concatenation of first element
+ if python.PyList_GET_SIZE(value) == 0:
+ return ''
+ firstnode = value[0]
+ if _isString(firstnode):
+ return firstnode
+ elif isinstance(firstnode, _Element):
+ c_text = tree.xmlNodeGetContent((<_Element>firstnode)._c_node)
+ if c_text is NULL:
+ raise MemoryError()
+ try:
+ return funicode(c_text)
+ finally:
+ tree.xmlFree(c_text)
+ else:
+ return unicode(firstnode)
+ else:
+ return unicode(value)
+
+ cdef _compile(self, rexp, ignore_case):
+ cdef python.PyObject* c_result
+ rexp = self._make_string(rexp)
+ key = (rexp, ignore_case)
+ c_result = python.PyDict_GetItem(self._compile_map, key)
+ if c_result is not NULL:
+ return c_result
+ py_flags = re.UNICODE
+ if ignore_case:
+ py_flags = py_flags | re.IGNORECASE
+ rexp_compiled = re.compile(rexp, py_flags)
+ self._compile_map[key] = rexp_compiled
+ return rexp_compiled
+
+ def test(self, ctxt, s, rexp, flags=''):
+ flags = self._make_string(flags)
+ s = self._make_string(s)
+ rexpc = self._compile(rexp, 'i' in flags)
+ if rexpc.search(s) is None:
+ return False
+ else:
+ return True
+
+ def match(self, ctxt, s, rexp, flags=''):
+ cdef list result_list
+ flags = self._make_string(flags)
+ s = self._make_string(s)
+ rexpc = self._compile(rexp, 'i' in flags)
+ if 'g' in flags:
+ results = rexpc.findall(s)
+ if not results:
+ return ()
+ else:
+ result = rexpc.search(s)
+ if not result:
+ return ()
+ results = [ result.group() ]
+ results.extend( result.groups('') )
+ result_list = []
+ root = Element('matches')
+ for s_match in results:
+ if python.PyTuple_CheckExact(s_match):
+ s_match = ''.join(s_match)
+ elem = SubElement(root, 'match')
+ elem.text = s_match
+ result_list.append(elem)
+ return result_list
+
+ def replace(self, ctxt, s, rexp, flags, replacement):
+ replacement = self._make_string(replacement)
+ flags = self._make_string(flags)
+ s = self._make_string(s)
+ rexpc = self._compile(rexp, 'i' in flags)
+ count: object = 0 if 'g' in flags else 1
+ return rexpc.sub(replacement, s, count)
+
+ cdef _register_in_context(self, _BaseContext context):
+ ns = b"http://exslt.org/regular-expressions"
+ context._addLocalExtensionFunction(ns, b"test", self.test)
+ context._addLocalExtensionFunction(ns, b"match", self.match)
+ context._addLocalExtensionFunction(ns, b"replace", self.replace)
+
+
+################################################################################
+# helper functions
+
+cdef xpath.xmlXPathObject* _wrapXPathObject(object obj, _Document doc,
+ _BaseContext context) except NULL:
+ cdef xpath.xmlNodeSet* resultSet
+ cdef _Element fake_node = None
+ cdef xmlNode* c_node
+
+ if isinstance(obj, unicode):
+ obj = _utf8(obj)
+ if isinstance(obj, bytes):
+ # libxml2 copies the string value
+ return xpath.xmlXPathNewCString(_cstr(obj))
+ if isinstance(obj, bool):
+ return xpath.xmlXPathNewBoolean(obj)
+ if python.PyNumber_Check(obj):
+ return xpath.xmlXPathNewFloat(obj)
+ if obj is None:
+ resultSet = xpath.xmlXPathNodeSetCreate(NULL)
+ elif isinstance(obj, _Element):
+ resultSet = xpath.xmlXPathNodeSetCreate((<_Element>obj)._c_node)
+ elif python.PySequence_Check(obj):
+ resultSet = xpath.xmlXPathNodeSetCreate(NULL)
+ try:
+ for value in obj:
+ if isinstance(value, _Element):
+ if context is not None:
+ context._hold(value)
+ xpath.xmlXPathNodeSetAdd(resultSet, (<_Element>value)._c_node)
+ else:
+ if context is None or doc is None:
+ raise XPathResultError, \
+ f"Non-Element values not supported at this point - got {value!r}"
+ # support strings by appending text nodes to an Element
+ if isinstance(value, unicode):
+ value = _utf8(value)
+ if isinstance(value, bytes):
+ if fake_node is None:
+ fake_node = _makeElement("text-root", NULL, doc, None,
+ None, None, None, None, None)
+ context._hold(fake_node)
+ else:
+ # append a comment node to keep the text nodes separate
+ c_node = tree.xmlNewDocComment(doc._c_doc, "")
+ if c_node is NULL:
+ raise MemoryError()
+ tree.xmlAddChild(fake_node._c_node, c_node)
+ context._hold(value)
+ c_node = tree.xmlNewDocText(doc._c_doc, _xcstr(value))
+ if c_node is NULL:
+ raise MemoryError()
+ tree.xmlAddChild(fake_node._c_node, c_node)
+ xpath.xmlXPathNodeSetAdd(resultSet, c_node)
+ else:
+ raise XPathResultError, \
+ f"This is not a supported node-set result: {value!r}"
+ except:
+ xpath.xmlXPathFreeNodeSet(resultSet)
+ raise
+ else:
+ raise XPathResultError, f"Unknown return type: {python._fqtypename(obj).decode('utf8')}"
+ return xpath.xmlXPathWrapNodeSet(resultSet)
+
+cdef object _unwrapXPathObject(xpath.xmlXPathObject* xpathObj,
+ _Document doc, _BaseContext context):
+ if xpathObj.type == xpath.XPATH_UNDEFINED:
+ raise XPathResultError, "Undefined xpath result"
+ elif xpathObj.type == xpath.XPATH_NODESET:
+ return _createNodeSetResult(xpathObj, doc, context)
+ elif xpathObj.type == xpath.XPATH_BOOLEAN:
+ return xpathObj.boolval
+ elif xpathObj.type == xpath.XPATH_NUMBER:
+ return xpathObj.floatval
+ elif xpathObj.type == xpath.XPATH_STRING:
+ stringval = funicode(xpathObj.stringval)
+ if context._build_smart_strings:
+ stringval = _elementStringResultFactory(
+ stringval, None, None, False)
+ return stringval
+ elif xpathObj.type == xpath.XPATH_POINT:
+ raise NotImplementedError, "XPATH_POINT"
+ elif xpathObj.type == xpath.XPATH_RANGE:
+ raise NotImplementedError, "XPATH_RANGE"
+ elif xpathObj.type == xpath.XPATH_LOCATIONSET:
+ raise NotImplementedError, "XPATH_LOCATIONSET"
+ elif xpathObj.type == xpath.XPATH_USERS:
+ raise NotImplementedError, "XPATH_USERS"
+ elif xpathObj.type == xpath.XPATH_XSLT_TREE:
+ return _createNodeSetResult(xpathObj, doc, context)
+ else:
+ raise XPathResultError, f"Unknown xpath result {xpathObj.type}"
+
+cdef object _createNodeSetResult(xpath.xmlXPathObject* xpathObj, _Document doc,
+ _BaseContext context):
+ cdef xmlNode* c_node
+ cdef int i
+ cdef list result
+ result = []
+ if xpathObj.nodesetval is NULL:
+ return result
+ for i in range(xpathObj.nodesetval.nodeNr):
+ c_node = xpathObj.nodesetval.nodeTab[i]
+ _unpackNodeSetEntry(result, c_node, doc, context,
+ xpathObj.type == xpath.XPATH_XSLT_TREE)
+ return result
+
+cdef _unpackNodeSetEntry(list results, xmlNode* c_node, _Document doc,
+ _BaseContext context, bint is_fragment):
+ cdef xmlNode* c_child
+ if _isElement(c_node):
+ if c_node.doc != doc._c_doc and c_node.doc._private is NULL:
+ # XXX: works, but maybe not always the right thing to do?
+ # XPath: only runs when extensions create or copy trees
+ # -> we store Python refs to these, so that is OK
+ # XSLT: can it leak when merging trees from multiple sources?
+ c_node = tree.xmlDocCopyNode(c_node, doc._c_doc, 1)
+ # FIXME: call _instantiateElementFromXPath() instead?
+ results.append(
+ _fakeDocElementFactory(doc, c_node))
+ elif c_node.type == tree.XML_TEXT_NODE or \
+ c_node.type == tree.XML_CDATA_SECTION_NODE or \
+ c_node.type == tree.XML_ATTRIBUTE_NODE:
+ results.append(
+ _buildElementStringResult(doc, c_node, context))
+ elif c_node.type == tree.XML_NAMESPACE_DECL:
+ results.append( (funicodeOrNone((c_node).prefix),
+ funicodeOrNone((c_node).href)) )
+ elif c_node.type == tree.XML_DOCUMENT_NODE or \
+ c_node.type == tree.XML_HTML_DOCUMENT_NODE:
+ # ignored for everything but result tree fragments
+ if is_fragment:
+ c_child = c_node.children
+ while c_child is not NULL:
+ _unpackNodeSetEntry(results, c_child, doc, context, 0)
+ c_child = c_child.next
+ elif c_node.type == tree.XML_XINCLUDE_START or \
+ c_node.type == tree.XML_XINCLUDE_END:
+ pass
+ else:
+ raise NotImplementedError, \
+ f"Not yet implemented result node type: {c_node.type}"
+
+cdef void _freeXPathObject(xpath.xmlXPathObject* xpathObj) noexcept:
+ """Free the XPath object, but *never* free the *content* of node sets.
+ Python dealloc will do that for us.
+ """
+ if xpathObj.nodesetval is not NULL:
+ xpath.xmlXPathFreeNodeSet(xpathObj.nodesetval)
+ xpathObj.nodesetval = NULL
+ xpath.xmlXPathFreeObject(xpathObj)
+
+cdef _Element _instantiateElementFromXPath(xmlNode* c_node, _Document doc,
+ _BaseContext context):
+ # NOTE: this may copy the element - only call this when it can't leak
+ if c_node.doc != doc._c_doc and c_node.doc._private is NULL:
+ # not from the context document and not from a fake document
+ # either => may still be from a known document, e.g. one
+ # created by an extension function
+ node_doc = context._findDocumentForNode(c_node)
+ if node_doc is None:
+ # not from a known document at all! => can only make a
+ # safety copy here
+ c_node = tree.xmlDocCopyNode(c_node, doc._c_doc, 1)
+ else:
+ doc = node_doc
+ return _fakeDocElementFactory(doc, c_node)
+
+################################################################################
+# special str/unicode subclasses
+
+@cython.final
+cdef class _ElementUnicodeResult(unicode):
+ cdef _Element _parent
+ cdef readonly object attrname
+ cdef readonly bint is_tail
+
+ def getparent(self):
+ return self._parent
+
+ @property
+ def is_text(self):
+ return self._parent is not None and not (self.is_tail or self.attrname is not None)
+
+ @property
+ def is_attribute(self):
+ return self.attrname is not None
+
+cdef object _elementStringResultFactory(string_value, _Element parent,
+ attrname, bint is_tail):
+ result = _ElementUnicodeResult(string_value)
+ result._parent = parent
+ result.is_tail = is_tail
+ result.attrname = attrname
+ return result
+
+cdef object _buildElementStringResult(_Document doc, xmlNode* c_node,
+ _BaseContext context):
+ cdef _Element parent = None
+ cdef object attrname = None
+ cdef xmlNode* c_element
+ cdef bint is_tail
+
+ if c_node.type == tree.XML_ATTRIBUTE_NODE:
+ attrname = _namespacedName(c_node)
+ is_tail = 0
+ s = tree.xmlNodeGetContent(c_node)
+ try:
+ value = funicode(s)
+ finally:
+ tree.xmlFree(s)
+ c_element = NULL
+ else:
+ #assert c_node.type == tree.XML_TEXT_NODE or c_node.type == tree.XML_CDATA_SECTION_NODE, "invalid node type"
+ # may be tail text or normal text
+ value = funicode(c_node.content)
+ c_element = _previousElement(c_node)
+ is_tail = c_element is not NULL
+
+ if not context._build_smart_strings:
+ return value
+
+ if c_element is NULL:
+ # non-tail text or attribute text
+ c_element = c_node.parent
+ while c_element is not NULL and not _isElement(c_element):
+ c_element = c_element.parent
+
+ if c_element is not NULL:
+ parent = _instantiateElementFromXPath(c_element, doc, context)
+
+ return _elementStringResultFactory(
+ value, parent, attrname, is_tail)
+
+################################################################################
+# callbacks for XPath/XSLT extension functions
+
+cdef void _extension_function_call(_BaseContext context, function,
+ xpath.xmlXPathParserContext* ctxt, int nargs) noexcept:
+ cdef _Document doc
+ cdef xpath.xmlXPathObject* obj
+ cdef list args
+ cdef int i
+ doc = context._doc
+ try:
+ args = []
+ for i in range(nargs):
+ obj = xpath.valuePop(ctxt)
+ o = _unwrapXPathObject(obj, doc, context)
+ _freeXPathObject(obj)
+ args.append(o)
+ args.reverse()
+
+ res = function(context, *args)
+ # wrap result for XPath consumption
+ obj = _wrapXPathObject(res, doc, context)
+ # prevent Python from deallocating elements handed to libxml2
+ context._hold(res)
+ xpath.valuePush(ctxt, obj)
+ except:
+ xpath.xmlXPathErr(ctxt, xpath.XPATH_EXPR_ERROR)
+ context._exc._store_raised()
+ finally:
+ return # swallow any further exceptions
+
+# lookup the function by name and call it
+
+cdef void _xpath_function_call(xpath.xmlXPathParserContext* ctxt,
+ int nargs) noexcept with gil:
+ cdef _BaseContext context
+ cdef xpath.xmlXPathContext* rctxt = ctxt.context
+ context = <_BaseContext> rctxt.userData
+ try:
+ function = context._find_cached_function(rctxt.functionURI, rctxt.function)
+ if function is not None:
+ _extension_function_call(context, function, ctxt, nargs)
+ else:
+ xpath.xmlXPathErr(ctxt, xpath.XPATH_UNKNOWN_FUNC_ERROR)
+ context._exc._store_exception(XPathFunctionError(
+ f"XPath function '{_namespacedNameFromNsName(rctxt.functionURI, rctxt.function)}' not found"))
+ except:
+ # may not be the right error, but we need to tell libxml2 *something*
+ xpath.xmlXPathErr(ctxt, xpath.XPATH_UNKNOWN_FUNC_ERROR)
+ context._exc._store_raised()
+ finally:
+ return # swallow any further exceptions
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/lxml.etree.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/lxml.etree.h
new file mode 100644
index 0000000000000000000000000000000000000000..17b99a7be5c4159429d575c9e98f621f57c8310c
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/lxml.etree.h
@@ -0,0 +1,244 @@
+/* Generated by Cython 3.1.4 */
+
+#ifndef __PYX_HAVE__lxml__etree
+#define __PYX_HAVE__lxml__etree
+
+#include "Python.h"
+struct LxmlDocument;
+struct LxmlElement;
+struct LxmlElementTree;
+struct LxmlElementTagMatcher;
+struct LxmlElementIterator;
+struct LxmlElementBase;
+struct LxmlElementClassLookup;
+struct LxmlFallbackElementClassLookup;
+
+/* "lxml/etree.pyx":451
+ *
+ * # type of a function that steps from node to node
+ * ctypedef public xmlNode* (*_node_to_node_function)(xmlNode*) # <<<<<<<<<<<<<<
+ *
+ *
+*/
+typedef xmlNode *(*_node_to_node_function)(xmlNode *);
+
+/* "lxml/etree.pyx":465
+ * # Public Python API
+ *
+ * @cython.final # <<<<<<<<<<<<<<
+ * @cython.freelist(8)
+ * cdef public class _Document [ type LxmlDocumentType, object LxmlDocument ]:
+*/
+struct LxmlDocument {
+ PyObject_HEAD
+ struct __pyx_vtabstruct_4lxml_5etree__Document *__pyx_vtab;
+ int _ns_counter;
+ PyObject *_prefix_tail;
+ xmlDoc *_c_doc;
+ struct __pyx_obj_4lxml_5etree__BaseParser *_parser;
+};
+
+/* "lxml/etree.pyx":817
+ *
+ *
+ * @cython.no_gc_clear # <<<<<<<<<<<<<<
+ * cdef public class _Element [ type LxmlElementType, object LxmlElement ]:
+ * """Element class.
+*/
+struct LxmlElement {
+ PyObject_HEAD
+ struct LxmlDocument *_doc;
+ xmlNode *_c_node;
+ PyObject *_tag;
+};
+
+/* "lxml/etree.pyx":1991
+ *
+ *
+ * cdef public class _ElementTree [ type LxmlElementTreeType, # <<<<<<<<<<<<<<
+ * object LxmlElementTree ]:
+ * cdef _Document _doc
+*/
+struct LxmlElementTree {
+ PyObject_HEAD
+ struct __pyx_vtabstruct_4lxml_5etree__ElementTree *__pyx_vtab;
+ struct LxmlDocument *_doc;
+ struct LxmlElement *_context_node;
+};
+
+/* "lxml/etree.pyx":2765
+ *
+ *
+ * cdef public class _ElementTagMatcher [ object LxmlElementTagMatcher, # <<<<<<<<<<<<<<
+ * type LxmlElementTagMatcherType ]:
+ * """
+*/
+struct LxmlElementTagMatcher {
+ PyObject_HEAD
+ struct __pyx_vtabstruct_4lxml_5etree__ElementTagMatcher *__pyx_vtab;
+ PyObject *_pystrings;
+ int _node_type;
+ char *_href;
+ char *_name;
+};
+
+/* "lxml/etree.pyx":2796
+ * self._name = NULL
+ *
+ * cdef public class _ElementIterator(_ElementTagMatcher) [ # <<<<<<<<<<<<<<
+ * object LxmlElementIterator, type LxmlElementIteratorType ]:
+ * """
+*/
+struct LxmlElementIterator {
+ struct LxmlElementTagMatcher __pyx_base;
+ struct LxmlElement *_node;
+ _node_to_node_function _next_element;
+};
+
+/* "src/lxml/classlookup.pxi":6
+ * # Custom Element classes
+ *
+ * cdef public class ElementBase(_Element) [ type LxmlElementBaseType, # <<<<<<<<<<<<<<
+ * object LxmlElementBase ]:
+ * """ElementBase(*children, attrib=None, nsmap=None, **_extra)
+*/
+struct LxmlElementBase {
+ struct LxmlElement __pyx_base;
+};
+
+/* "src/lxml/classlookup.pxi":210
+ * # Element class lookup
+ *
+ * ctypedef public object (*_element_class_lookup_function)(object, _Document, xmlNode*) # <<<<<<<<<<<<<<
+ *
+ * # class to store element class lookup functions
+*/
+typedef PyObject *(*_element_class_lookup_function)(PyObject *, struct LxmlDocument *, xmlNode *);
+
+/* "src/lxml/classlookup.pxi":213
+ *
+ * # class to store element class lookup functions
+ * cdef public class ElementClassLookup [ type LxmlElementClassLookupType, # <<<<<<<<<<<<<<
+ * object LxmlElementClassLookup ]:
+ * """ElementClassLookup(self)
+*/
+struct LxmlElementClassLookup {
+ PyObject_HEAD
+ _element_class_lookup_function _lookup_function;
+};
+
+/* "src/lxml/classlookup.pxi":221
+ *
+ *
+ * cdef public class FallbackElementClassLookup(ElementClassLookup) \ # <<<<<<<<<<<<<<
+ * [ type LxmlFallbackElementClassLookupType,
+ * object LxmlFallbackElementClassLookup ]:
+*/
+struct LxmlFallbackElementClassLookup {
+ struct LxmlElementClassLookup __pyx_base;
+ struct __pyx_vtabstruct_4lxml_5etree_FallbackElementClassLookup *__pyx_vtab;
+ struct LxmlElementClassLookup *fallback;
+ _element_class_lookup_function _fallback_function;
+};
+
+#ifndef __PYX_HAVE_API__lxml__etree
+
+#ifdef CYTHON_EXTERN_C
+ #undef __PYX_EXTERN_C
+ #define __PYX_EXTERN_C CYTHON_EXTERN_C
+#elif defined(__PYX_EXTERN_C)
+ #ifdef _MSC_VER
+ #pragma message ("Please do not define the '__PYX_EXTERN_C' macro externally. Use 'CYTHON_EXTERN_C' instead.")
+ #else
+ #warning Please do not define the '__PYX_EXTERN_C' macro externally. Use 'CYTHON_EXTERN_C' instead.
+ #endif
+#else
+ #ifdef __cplusplus
+ #define __PYX_EXTERN_C extern "C"
+ #else
+ #define __PYX_EXTERN_C extern
+ #endif
+#endif
+
+#ifndef DL_IMPORT
+ #define DL_IMPORT(_T) _T
+#endif
+
+__PYX_EXTERN_C DL_IMPORT(PyTypeObject) LxmlDocumentType;
+__PYX_EXTERN_C DL_IMPORT(PyTypeObject) LxmlElementType;
+__PYX_EXTERN_C DL_IMPORT(PyTypeObject) LxmlElementTreeType;
+__PYX_EXTERN_C DL_IMPORT(PyTypeObject) LxmlElementTagMatcherType;
+__PYX_EXTERN_C DL_IMPORT(PyTypeObject) LxmlElementIteratorType;
+__PYX_EXTERN_C DL_IMPORT(PyTypeObject) LxmlElementBaseType;
+__PYX_EXTERN_C DL_IMPORT(PyTypeObject) LxmlElementClassLookupType;
+__PYX_EXTERN_C DL_IMPORT(PyTypeObject) LxmlFallbackElementClassLookupType;
+
+__PYX_EXTERN_C struct LxmlElement *deepcopyNodeToDocument(struct LxmlDocument *, xmlNode *);
+__PYX_EXTERN_C struct LxmlElementTree *elementTreeFactory(struct LxmlElement *);
+__PYX_EXTERN_C struct LxmlElementTree *newElementTree(struct LxmlElement *, PyObject *);
+__PYX_EXTERN_C struct LxmlElementTree *adoptExternalDocument(xmlDoc *, PyObject *, int);
+__PYX_EXTERN_C struct LxmlElement *elementFactory(struct LxmlDocument *, xmlNode *);
+__PYX_EXTERN_C struct LxmlElement *makeElement(PyObject *, struct LxmlDocument *, PyObject *, PyObject *, PyObject *, PyObject *, PyObject *);
+__PYX_EXTERN_C struct LxmlElement *makeSubElement(struct LxmlElement *, PyObject *, PyObject *, PyObject *, PyObject *, PyObject *);
+__PYX_EXTERN_C void setElementClassLookupFunction(_element_class_lookup_function, PyObject *);
+__PYX_EXTERN_C PyObject *lookupDefaultElementClass(PyObject *, PyObject *, xmlNode *);
+__PYX_EXTERN_C PyObject *lookupNamespaceElementClass(PyObject *, PyObject *, xmlNode *);
+__PYX_EXTERN_C PyObject *callLookupFallback(struct LxmlFallbackElementClassLookup *, struct LxmlDocument *, xmlNode *);
+__PYX_EXTERN_C int tagMatches(xmlNode *, const xmlChar *, const xmlChar *);
+__PYX_EXTERN_C struct LxmlDocument *documentOrRaise(PyObject *);
+__PYX_EXTERN_C struct LxmlElement *rootNodeOrRaise(PyObject *);
+__PYX_EXTERN_C int hasText(xmlNode *);
+__PYX_EXTERN_C int hasTail(xmlNode *);
+__PYX_EXTERN_C PyObject *textOf(xmlNode *);
+__PYX_EXTERN_C PyObject *tailOf(xmlNode *);
+__PYX_EXTERN_C int setNodeText(xmlNode *, PyObject *);
+__PYX_EXTERN_C int setTailText(xmlNode *, PyObject *);
+__PYX_EXTERN_C PyObject *attributeValue(xmlNode *, xmlAttr *);
+__PYX_EXTERN_C PyObject *attributeValueFromNsName(xmlNode *, const xmlChar *, const xmlChar *);
+__PYX_EXTERN_C PyObject *getAttributeValue(struct LxmlElement *, PyObject *, PyObject *);
+__PYX_EXTERN_C PyObject *iterattributes(struct LxmlElement *, int);
+__PYX_EXTERN_C PyObject *collectAttributes(xmlNode *, int);
+__PYX_EXTERN_C int setAttributeValue(struct LxmlElement *, PyObject *, PyObject *);
+__PYX_EXTERN_C int delAttribute(struct LxmlElement *, PyObject *);
+__PYX_EXTERN_C int delAttributeFromNsName(xmlNode *, const xmlChar *, const xmlChar *);
+__PYX_EXTERN_C int hasChild(xmlNode *);
+__PYX_EXTERN_C xmlNode *findChild(xmlNode *, Py_ssize_t);
+__PYX_EXTERN_C xmlNode *findChildForwards(xmlNode *, Py_ssize_t);
+__PYX_EXTERN_C xmlNode *findChildBackwards(xmlNode *, Py_ssize_t);
+__PYX_EXTERN_C xmlNode *nextElement(xmlNode *);
+__PYX_EXTERN_C xmlNode *previousElement(xmlNode *);
+__PYX_EXTERN_C void appendChild(struct LxmlElement *, struct LxmlElement *);
+__PYX_EXTERN_C int appendChildToElement(struct LxmlElement *, struct LxmlElement *);
+__PYX_EXTERN_C PyObject *pyunicode(const xmlChar *);
+__PYX_EXTERN_C PyObject *utf8(PyObject *);
+__PYX_EXTERN_C PyObject *getNsTag(PyObject *);
+__PYX_EXTERN_C PyObject *getNsTagWithEmptyNs(PyObject *);
+__PYX_EXTERN_C PyObject *namespacedName(xmlNode *);
+__PYX_EXTERN_C PyObject *namespacedNameFromNsName(const xmlChar *, const xmlChar *);
+__PYX_EXTERN_C void iteratorStoreNext(struct LxmlElementIterator *, struct LxmlElement *);
+__PYX_EXTERN_C void initTagMatch(struct LxmlElementTagMatcher *, PyObject *);
+__PYX_EXTERN_C xmlNs *findOrBuildNodeNsPrefix(struct LxmlDocument *, xmlNode *, const xmlChar *, const xmlChar *);
+
+#endif /* !__PYX_HAVE_API__lxml__etree */
+
+/* WARNING: the interface of the module init function changed in CPython 3.5. */
+/* It now returns a PyModuleDef instance instead of a PyModule instance. */
+
+/* WARNING: Use PyImport_AppendInittab("etree", PyInit_etree) instead of calling PyInit_etree directly from Python 3.5 */
+PyMODINIT_FUNC PyInit_etree(void);
+
+#if PY_VERSION_HEX >= 0x03050000 && (defined(__GNUC__) || defined(__clang__) || defined(_MSC_VER) || (defined(__cplusplus) && __cplusplus >= 201402L))
+#if defined(__cplusplus) && __cplusplus >= 201402L
+[[deprecated("Use PyImport_AppendInittab(\"etree\", PyInit_etree) instead of calling PyInit_etree directly.")]] inline
+#elif defined(__GNUC__) || defined(__clang__)
+__attribute__ ((__deprecated__("Use PyImport_AppendInittab(\"etree\", PyInit_etree) instead of calling PyInit_etree directly."), __unused__)) __inline__
+#elif defined(_MSC_VER)
+__declspec(deprecated("Use PyImport_AppendInittab(\"etree\", PyInit_etree) instead of calling PyInit_etree directly.")) __inline
+#endif
+static PyObject* __PYX_WARN_IF_PyInit_etree_INIT_CALLED(PyObject* res) {
+ return res;
+}
+#define PyInit_etree() __PYX_WARN_IF_PyInit_etree_INIT_CALLED(PyInit_etree())
+#endif
+
+#endif /* !__PYX_HAVE__lxml__etree */
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/objectpath.pxi b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/objectpath.pxi
new file mode 100644
index 0000000000000000000000000000000000000000..e562a365015830bfd3d24650d1109fe891c31039
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/objectpath.pxi
@@ -0,0 +1,332 @@
+################################################################################
+# ObjectPath
+
+ctypedef struct _ObjectPath:
+ const_xmlChar* href
+ const_xmlChar* name
+ Py_ssize_t index
+
+
+cdef object _NO_DEFAULT = object()
+
+
+cdef class ObjectPath:
+ """ObjectPath(path)
+ Immutable object that represents a compiled object path.
+
+ Example for a path: 'root.child[1].{other}child[25]'
+ """
+ cdef readonly object find
+ cdef list _path
+ cdef object _path_str
+ cdef _ObjectPath* _c_path
+ cdef Py_ssize_t _path_len
+ def __init__(self, path):
+ if python._isString(path):
+ self._path = _parse_object_path_string(path)
+ self._path_str = path
+ else:
+ self._path = _parse_object_path_list(path)
+ self._path_str = '.'.join(path)
+ self._path_len = len(self._path)
+ self._c_path = _build_object_path_segments(self._path)
+ self.find = self.__call__
+
+ def __dealloc__(self):
+ if self._c_path is not NULL:
+ python.lxml_free(self._c_path)
+
+ def __str__(self):
+ return self._path_str
+
+ def __call__(self, _Element root not None, *_default):
+ """Follow the attribute path in the object structure and return the
+ target attribute value.
+
+ If it it not found, either returns a default value (if one was passed
+ as second argument) or raises AttributeError.
+ """
+ if _default:
+ if len(_default) > 1:
+ raise TypeError, "invalid number of arguments: needs one or two"
+ default = _default[0]
+ else:
+ default = _NO_DEFAULT
+ return _find_object_path(root, self._c_path, self._path_len, default)
+
+ def hasattr(self, _Element root not None):
+ "hasattr(self, root)"
+ try:
+ _find_object_path(root, self._c_path, self._path_len, _NO_DEFAULT)
+ except AttributeError:
+ return False
+ return True
+
+ def setattr(self, _Element root not None, value):
+ """setattr(self, root, value)
+
+ Set the value of the target element in a subtree.
+
+ If any of the children on the path does not exist, it is created.
+ """
+ _create_object_path(root, self._c_path, self._path_len, 1, value)
+
+ def addattr(self, _Element root not None, value):
+ """addattr(self, root, value)
+
+ Append a value to the target element in a subtree.
+
+ If any of the children on the path does not exist, it is created.
+ """
+ _create_object_path(root, self._c_path, self._path_len, 0, value)
+
+
+cdef object __MATCH_PATH_SEGMENT = re.compile(
+ r"(\.?)\s*(?:\{([^}]*)\})?\s*([^.{}\[\]\s]+)\s*(?:\[\s*([-0-9]+)\s*\])?",
+ re.U).match
+
+cdef tuple _RELATIVE_PATH_SEGMENT = (None, None, 0)
+
+
+cdef list _parse_object_path_string(_path):
+ """Parse object path string into a (ns, name, index) list.
+ """
+ cdef bint has_dot
+ cdef unicode path
+ new_path = []
+ if isinstance(_path, bytes):
+ path = (_path).decode('ascii')
+ elif type(_path) is not unicode:
+ path = unicode(_path)
+ else:
+ path = _path
+ path = path.strip()
+ if path == '.':
+ return [_RELATIVE_PATH_SEGMENT]
+ path_pos = 0
+ while path:
+ match = __MATCH_PATH_SEGMENT(path, path_pos)
+ if match is None:
+ break
+
+ dot, ns, name, index = match.groups()
+ index = int(index) if index else 0
+ has_dot = dot == '.'
+ if not new_path:
+ if has_dot:
+ # path '.child' => ignore root
+ new_path.append(_RELATIVE_PATH_SEGMENT)
+ elif index:
+ raise ValueError, "index not allowed on root node"
+ elif not has_dot:
+ raise ValueError, "invalid path"
+ if ns is not None:
+ ns = python.PyUnicode_AsUTF8String(ns)
+ name = python.PyUnicode_AsUTF8String(name)
+ new_path.append( (ns, name, index) )
+
+ path_pos = match.end()
+ if not new_path or len(path) > path_pos:
+ raise ValueError, "invalid path"
+ return new_path
+
+
+cdef list _parse_object_path_list(path):
+ """Parse object path sequence into a (ns, name, index) list.
+ """
+ new_path = []
+ for item in path:
+ item = item.strip()
+ if not new_path and item == '':
+ # path '.child' => ignore root
+ ns = name = None
+ index = 0
+ else:
+ ns, name = cetree.getNsTag(item)
+ c_name = _xcstr(name)
+ index_pos = tree.xmlStrchr(c_name, c'[')
+ if index_pos is NULL:
+ index = 0
+ else:
+ index_end = tree.xmlStrchr(index_pos + 1, c']')
+ if index_end is NULL:
+ raise ValueError, "index must be enclosed in []"
+ index = int(index_pos[1:index_end - index_pos])
+ if not new_path and index != 0:
+ raise ValueError, "index not allowed on root node"
+ name = c_name[:index_pos - c_name]
+ new_path.append( (ns, name, index) )
+ if not new_path:
+ raise ValueError, "invalid path"
+ return new_path
+
+
+cdef _ObjectPath* _build_object_path_segments(list path_list) except NULL:
+ cdef _ObjectPath* c_path
+ cdef _ObjectPath* c_path_segments
+ c_path_segments = <_ObjectPath*>python.lxml_malloc(len(path_list), sizeof(_ObjectPath))
+ if c_path_segments is NULL:
+ raise MemoryError()
+ c_path = c_path_segments
+ for href, name, index in path_list:
+ c_path[0].href = _xcstr(href) if href is not None else NULL
+ c_path[0].name = _xcstr(name) if name is not None else NULL
+ c_path[0].index = index
+ c_path += 1
+ return c_path_segments
+
+
+cdef _find_object_path(_Element root, _ObjectPath* c_path, Py_ssize_t c_path_len, default_value):
+ """Follow the path to find the target element.
+ """
+ cdef tree.xmlNode* c_node
+ cdef Py_ssize_t c_index
+ c_node = root._c_node
+ c_name = c_path[0].name
+ c_href = c_path[0].href
+ if c_href is NULL or c_href[0] == c'\0':
+ c_href = tree._getNs(c_node)
+ if not cetree.tagMatches(c_node, c_href, c_name):
+ if default_value is not _NO_DEFAULT:
+ return default_value
+ else:
+ raise ValueError(
+ f"root element does not match: need {cetree.namespacedNameFromNsName(c_href, c_name)}, got {root.tag}")
+
+ while c_node is not NULL:
+ c_path_len -= 1
+ if c_path_len <= 0:
+ break
+
+ c_path += 1
+ if c_path[0].href is not NULL:
+ c_href = c_path[0].href # otherwise: keep parent namespace
+ c_name = tree.xmlDictExists(c_node.doc.dict, c_path[0].name, -1)
+ if c_name is NULL:
+ c_name = c_path[0].name
+ c_node = NULL
+ break
+ c_index = c_path[0].index
+ c_node = c_node.last if c_index < 0 else c_node.children
+ c_node = _findFollowingSibling(c_node, c_href, c_name, c_index)
+
+ if c_node is not NULL:
+ return cetree.elementFactory(root._doc, c_node)
+ elif default_value is not _NO_DEFAULT:
+ return default_value
+ else:
+ tag = cetree.namespacedNameFromNsName(c_href, c_name)
+ raise AttributeError, f"no such child: {tag}"
+
+
+cdef _create_object_path(_Element root, _ObjectPath* c_path,
+ Py_ssize_t c_path_len, int replace, value):
+ """Follow the path to find the target element, build the missing children
+ as needed and set the target element to 'value'. If replace is true, an
+ existing value is replaced, otherwise the new value is added.
+ """
+ cdef _Element child
+ cdef tree.xmlNode* c_node
+ cdef tree.xmlNode* c_child
+ cdef Py_ssize_t c_index
+ if c_path_len == 1:
+ raise TypeError, "cannot update root node"
+
+ c_node = root._c_node
+ c_name = c_path[0].name
+ c_href = c_path[0].href
+ if c_href is NULL or c_href[0] == c'\0':
+ c_href = tree._getNs(c_node)
+ if not cetree.tagMatches(c_node, c_href, c_name):
+ raise ValueError(
+ f"root element does not match: need {cetree.namespacedNameFromNsName(c_href, c_name)}, got {root.tag}")
+
+ while c_path_len > 1:
+ c_path_len -= 1
+ c_path += 1
+ if c_path[0].href is not NULL:
+ c_href = c_path[0].href # otherwise: keep parent namespace
+ c_index = c_path[0].index
+ c_name = tree.xmlDictExists(c_node.doc.dict, c_path[0].name, -1)
+ if c_name is NULL:
+ c_name = c_path[0].name
+ c_child = NULL
+ else:
+ c_child = c_node.last if c_index < 0 else c_node.children
+ c_child = _findFollowingSibling(c_child, c_href, c_name, c_index)
+
+ if c_child is not NULL:
+ c_node = c_child
+ elif c_index != 0:
+ raise TypeError, "creating indexed path attributes is not supported"
+ elif c_path_len == 1:
+ _appendValue(cetree.elementFactory(root._doc, c_node),
+ cetree.namespacedNameFromNsName(c_href, c_name),
+ value)
+ return
+ else:
+ child = cetree.makeSubElement(
+ cetree.elementFactory(root._doc, c_node),
+ cetree.namespacedNameFromNsName(c_href, c_name),
+ None, None, None, None)
+ c_node = child._c_node
+
+ # if we get here, the entire path was already there
+ if replace:
+ element = cetree.elementFactory(root._doc, c_node)
+ _replaceElement(element, value)
+ else:
+ _appendValue(cetree.elementFactory(root._doc, c_node.parent),
+ cetree.namespacedName(c_node), value)
+
+
+cdef list _build_descendant_paths(tree.xmlNode* c_node, prefix_string):
+ """Returns a list of all descendant paths.
+ """
+ cdef list path, path_list
+ tag = cetree.namespacedName(c_node)
+ if prefix_string:
+ if prefix_string[-1] != '.':
+ prefix_string += '.'
+ prefix_string = prefix_string + tag
+ else:
+ prefix_string = tag
+ path = [prefix_string]
+ path_list = []
+ _recursive_build_descendant_paths(c_node, path, path_list)
+ return path_list
+
+
+cdef int _recursive_build_descendant_paths(tree.xmlNode* c_node,
+ list path, list path_list) except -1:
+ """Fills the list 'path_list' with all descendant paths, initial prefix
+ being in the list 'path'.
+ """
+ cdef tree.xmlNode* c_child
+ tags = {}
+ path_list.append('.'.join(path))
+ c_href = tree._getNs(c_node)
+ c_child = c_node.children
+ while c_child is not NULL:
+ while c_child.type != tree.XML_ELEMENT_NODE:
+ c_child = c_child.next
+ if c_child is NULL:
+ return 0
+ if c_href is tree._getNs(c_child):
+ tag = pyunicode(c_child.name)
+ elif c_href is not NULL and tree._getNs(c_child) is NULL:
+ # special case: parent has namespace, child does not
+ tag = '{}' + pyunicode(c_child.name)
+ else:
+ tag = cetree.namespacedName(c_child)
+ count = tags.get(tag)
+ if count is None:
+ tags[tag] = 1
+ else:
+ tags[tag] = count + 1
+ tag += f'[{count}]'
+ path.append(tag)
+ _recursive_build_descendant_paths(c_child, path, path_list)
+ del path[-1]
+ c_child = c_child.next
+ return 0
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/proxy.pxi b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/proxy.pxi
new file mode 100644
index 0000000000000000000000000000000000000000..0e6cf19ef7ffadcb061aff55cdd6acb4d2ccfefa
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/proxy.pxi
@@ -0,0 +1,622 @@
+# Proxy functions and low level node allocation stuff
+
+# Proxies represent elements, their reference is stored in the C
+# structure of the respective node to avoid multiple instantiation of
+# the Python class.
+
+@cython.linetrace(False)
+@cython.profile(False)
+cdef inline _Element getProxy(xmlNode* c_node):
+ """Get a proxy for a given node.
+ """
+ #print "getProxy for:", c_node
+ if c_node is not NULL and c_node._private is not NULL:
+ return <_Element>c_node._private
+ else:
+ return None
+
+
+@cython.linetrace(False)
+@cython.profile(False)
+cdef inline bint hasProxy(xmlNode* c_node):
+ if c_node._private is NULL:
+ return False
+ return True
+
+
+@cython.linetrace(False)
+@cython.profile(False)
+cdef inline int _registerProxy(_Element proxy, _Document doc,
+ xmlNode* c_node) except -1:
+ """Register a proxy and type for the node it's proxying for.
+ """
+ #print "registering for:", proxy._c_node
+ assert not hasProxy(c_node), "double registering proxy!"
+ proxy._doc = doc
+ proxy._c_node = c_node
+ c_node._private = proxy
+ return 0
+
+
+@cython.linetrace(False)
+@cython.profile(False)
+cdef inline int _unregisterProxy(_Element proxy) except -1:
+ """Unregister a proxy for the node it's proxying for.
+ """
+ cdef xmlNode* c_node = proxy._c_node
+ assert c_node._private is proxy, "Tried to unregister unknown proxy"
+ c_node._private = NULL
+ return 0
+
+
+################################################################################
+# temporarily make a node the root node of its document
+
+cdef xmlDoc* _fakeRootDoc(xmlDoc* c_base_doc, xmlNode* c_node) except NULL:
+ return _plainFakeRootDoc(c_base_doc, c_node, 1)
+
+cdef xmlDoc* _plainFakeRootDoc(xmlDoc* c_base_doc, xmlNode* c_node,
+ bint with_siblings) except NULL:
+ # build a temporary document that has the given node as root node
+ # note that copy and original must not be modified during its lifetime!!
+ # always call _destroyFakeDoc() after use!
+ cdef xmlNode* c_child
+ cdef xmlNode* c_root
+ cdef xmlNode* c_new_root
+ cdef xmlDoc* c_doc
+ if with_siblings or (c_node.prev is NULL and c_node.next is NULL):
+ c_root = tree.xmlDocGetRootElement(c_base_doc)
+ if c_root is c_node:
+ # already the root node, no siblings
+ return c_base_doc
+
+ c_doc = _copyDoc(c_base_doc, 0) # non recursive!
+ c_new_root = tree.xmlDocCopyNode(c_node, c_doc, 2) # non recursive!
+ tree.xmlDocSetRootElement(c_doc, c_new_root)
+ _copyParentNamespaces(c_node, c_new_root)
+
+ c_new_root.children = c_node.children
+ c_new_root.last = c_node.last
+ c_new_root.next = c_new_root.prev = NULL
+
+ # store original node
+ c_doc._private = c_node
+
+ # divert parent pointers of children
+ c_child = c_new_root.children
+ while c_child is not NULL:
+ c_child.parent = c_new_root
+ c_child = c_child.next
+
+ c_doc.children = c_new_root
+ return c_doc
+
+cdef void _destroyFakeDoc(xmlDoc* c_base_doc, xmlDoc* c_doc) noexcept:
+ # delete a temporary document
+ cdef xmlNode* c_child
+ cdef xmlNode* c_parent
+ cdef xmlNode* c_root
+ if c_doc is c_base_doc:
+ return
+ c_root = tree.xmlDocGetRootElement(c_doc)
+
+ # restore parent pointers of children
+ c_parent = c_doc._private
+ c_child = c_root.children
+ while c_child is not NULL:
+ c_child.parent = c_parent
+ c_child = c_child.next
+
+ # prevent recursive removal of children
+ c_root.children = c_root.last = NULL
+ tree.xmlFreeDoc(c_doc)
+
+cdef _Element _fakeDocElementFactory(_Document doc, xmlNode* c_element):
+ """Special element factory for cases where we need to create a fake
+ root document, but still need to instantiate arbitrary nodes from
+ it. If we instantiate the fake root node, things will turn bad
+ when it's destroyed.
+
+ Instead, if we are asked to instantiate the fake root node, we
+ instantiate the original node instead.
+ """
+ if c_element.doc is not doc._c_doc:
+ if c_element.doc._private is not NULL:
+ if c_element is c_element.doc.children:
+ c_element = c_element.doc._private
+ #assert c_element.type == tree.XML_ELEMENT_NODE
+ return _elementFactory(doc, c_element)
+
+################################################################################
+# support for freeing tree elements when proxy objects are destroyed
+
+cdef int attemptDeallocation(xmlNode* c_node) noexcept:
+ """Attempt deallocation of c_node (or higher up in tree).
+ """
+ cdef xmlNode* c_top
+ # could be we actually aren't referring to the tree at all
+ if c_node is NULL:
+ #print "not freeing, node is NULL"
+ return 0
+ c_top = getDeallocationTop(c_node)
+ if c_top is not NULL:
+ #print "freeing:", c_top.name
+ _removeText(c_top.next) # tail
+ tree.xmlFreeNode(c_top)
+ return 1
+ return 0
+
+cdef xmlNode* getDeallocationTop(xmlNode* c_node) noexcept:
+ """Return the top of the tree that can be deallocated, or NULL.
+ """
+ cdef xmlNode* c_next
+ #print "trying to do deallocating:", c_node.type
+ if hasProxy(c_node):
+ #print "Not freeing: proxies still exist"
+ return NULL
+ while c_node.parent is not NULL:
+ c_node = c_node.parent
+ #print "checking:", c_current.type
+ if c_node.type == tree.XML_DOCUMENT_NODE or \
+ c_node.type == tree.XML_HTML_DOCUMENT_NODE:
+ #print "not freeing: still in doc"
+ return NULL
+ # if we're still attached to the document, don't deallocate
+ if hasProxy(c_node):
+ #print "Not freeing: proxies still exist"
+ return NULL
+ # see whether we have children to deallocate
+ if not canDeallocateChildNodes(c_node):
+ return NULL
+ # see whether we have siblings to deallocate
+ c_next = c_node.prev
+ while c_next:
+ if _isElement(c_next):
+ if hasProxy(c_next) or not canDeallocateChildNodes(c_next):
+ return NULL
+ c_next = c_next.prev
+ c_next = c_node.next
+ while c_next:
+ if _isElement(c_next):
+ if hasProxy(c_next) or not canDeallocateChildNodes(c_next):
+ return NULL
+ c_next = c_next.next
+ return c_node
+
+cdef int canDeallocateChildNodes(xmlNode* c_parent) noexcept:
+ cdef xmlNode* c_node
+ c_node = c_parent.children
+ tree.BEGIN_FOR_EACH_ELEMENT_FROM(c_parent, c_node, 1)
+ if hasProxy(c_node):
+ return 0
+ tree.END_FOR_EACH_ELEMENT_FROM(c_node)
+ return 1
+
+################################################################################
+# fix _Document references and namespaces when a node changes documents
+
+cdef void _copyParentNamespaces(xmlNode* c_from_node, xmlNode* c_to_node) noexcept nogil:
+ """Copy the namespaces of all ancestors of c_from_node to c_to_node.
+ """
+ cdef xmlNode* c_parent
+ cdef xmlNs* c_ns
+ cdef xmlNs* c_new_ns
+ cdef int prefix_known
+ c_parent = c_from_node.parent
+ while c_parent and (tree._isElementOrXInclude(c_parent) or
+ c_parent.type == tree.XML_DOCUMENT_NODE):
+ c_new_ns = c_parent.nsDef
+ while c_new_ns:
+ # libxml2 will check if the prefix is already defined
+ tree.xmlNewNs(c_to_node, c_new_ns.href, c_new_ns.prefix)
+ c_new_ns = c_new_ns.next
+ c_parent = c_parent.parent
+
+
+ctypedef struct _ns_update_map:
+ xmlNs* old
+ xmlNs* new
+
+
+ctypedef struct _nscache:
+ _ns_update_map* ns_map
+ size_t size
+ size_t last
+
+
+cdef int _growNsCache(_nscache* c_ns_cache) except -1:
+ cdef _ns_update_map* ns_map_ptr
+ if c_ns_cache.size == 0:
+ c_ns_cache.size = 20
+ else:
+ c_ns_cache.size *= 2
+ ns_map_ptr = <_ns_update_map*> python.lxml_realloc(
+ c_ns_cache.ns_map, c_ns_cache.size, sizeof(_ns_update_map))
+ if not ns_map_ptr:
+ python.lxml_free(c_ns_cache.ns_map)
+ c_ns_cache.ns_map = NULL
+ raise MemoryError()
+ c_ns_cache.ns_map = ns_map_ptr
+ return 0
+
+
+cdef inline int _appendToNsCache(_nscache* c_ns_cache,
+ xmlNs* c_old_ns, xmlNs* c_new_ns) except -1:
+ if c_ns_cache.last >= c_ns_cache.size:
+ _growNsCache(c_ns_cache)
+ c_ns_cache.ns_map[c_ns_cache.last] = _ns_update_map(old=c_old_ns, new=c_new_ns)
+ c_ns_cache.last += 1
+
+
+cdef int _stripRedundantNamespaceDeclarations(xmlNode* c_element, _nscache* c_ns_cache,
+ xmlNs** c_del_ns_list) except -1:
+ """Removes namespace declarations from an element that are already
+ defined in its parents. Does not free the xmlNs's, just prepends
+ them to the c_del_ns_list.
+ """
+ cdef xmlNs* c_ns
+ cdef xmlNs* c_ns_next
+ cdef xmlNs** c_nsdef
+ # use a xmlNs** to handle assignments to "c_element.nsDef" correctly
+ c_nsdef = &c_element.nsDef
+ while c_nsdef[0] is not NULL:
+ c_ns = tree.xmlSearchNsByHref(
+ c_element.doc, c_element.parent, c_nsdef[0].href)
+ if c_ns is NULL:
+ # new namespace href => keep and cache the ns declaration
+ _appendToNsCache(c_ns_cache, c_nsdef[0], c_nsdef[0])
+ c_nsdef = &c_nsdef[0].next
+ else:
+ # known namespace href => cache mapping and strip old ns
+ _appendToNsCache(c_ns_cache, c_nsdef[0], c_ns)
+ # cut out c_nsdef.next and prepend it to garbage chain
+ c_ns_next = c_nsdef[0].next
+ c_nsdef[0].next = c_del_ns_list[0]
+ c_del_ns_list[0] = c_nsdef[0]
+ c_nsdef[0] = c_ns_next
+ return 0
+
+
+cdef void _cleanUpFromNamespaceAdaptation(xmlNode* c_start_node,
+ _nscache* c_ns_cache, xmlNs* c_del_ns_list) noexcept:
+ # Try to recover from exceptions with really bad timing. We were in the middle
+ # of ripping out xmlNS-es and likely ran out of memory. Try to fix up the tree
+ # by re-adding the original xmlNs declarations (which might still be used in some
+ # places).
+ if c_ns_cache.ns_map:
+ python.lxml_free(c_ns_cache.ns_map)
+ if c_del_ns_list:
+ if not c_start_node.nsDef:
+ c_start_node.nsDef = c_del_ns_list
+ else:
+ c_ns = c_start_node.nsDef
+ while c_ns.next:
+ c_ns = c_ns.next
+ c_ns.next = c_del_ns_list
+
+
+cdef int moveNodeToDocument(_Document doc, xmlDoc* c_source_doc,
+ xmlNode* c_element) except -1:
+ """Fix the xmlNs pointers of a node and its subtree that were moved.
+
+ Originally copied from libxml2's xmlReconciliateNs(). Expects
+ libxml2 doc pointers of node to be correct already, but fixes
+ _Document references.
+
+ For each node in the subtree, we do this:
+
+ 1) Remove redundant declarations of namespace that are already
+ defined in its parents.
+
+ 2) Replace namespaces that are *not* defined on the node or its
+ parents by the equivalent namespace declarations that *are*
+ defined on the node or its parents (possibly using a different
+ prefix). If a namespace is unknown, declare a new one on the
+ node.
+
+ 3) Reassign the names of tags and attribute from the dict of the
+ target document *iff* it is different from the dict used in the
+ source subtree.
+
+ 4) Set the Document reference to the new Document (if different).
+ This is done on backtracking to keep the original Document
+ alive as long as possible, until all its elements are updated.
+
+ Note that the namespace declarations are removed from the tree in
+ step 1), but freed only after the complete subtree was traversed
+ and all occurrences were replaced by tree-internal pointers.
+ """
+ cdef xmlNode* c_start_node
+ cdef xmlNode* c_node
+ cdef xmlDoc* c_doc = doc._c_doc
+ cdef tree.xmlAttr* c_attr
+ cdef char* c_name
+ cdef _nscache c_ns_cache = [NULL, 0, 0]
+ cdef xmlNs* c_del_ns_list = NULL
+ cdef proxy_count = 0
+
+ if not tree._isElementOrXInclude(c_element):
+ return 0
+
+ c_start_node = c_element
+
+ tree.BEGIN_FOR_EACH_FROM(c_element, c_element, 1)
+ if tree._isElementOrXInclude(c_element):
+ if hasProxy(c_element):
+ proxy_count += 1
+
+ # 1) cut out namespaces defined here that are already known by
+ # the ancestors
+ if c_element.nsDef is not NULL:
+ try:
+ _stripRedundantNamespaceDeclarations(c_element, &c_ns_cache, &c_del_ns_list)
+ except:
+ _cleanUpFromNamespaceAdaptation(c_start_node, &c_ns_cache, c_del_ns_list)
+ raise
+
+ # 2) make sure the namespaces of an element and its attributes
+ # are declared in this document (i.e. on the node or its parents)
+ if c_element.ns is not NULL:
+ _fixCNs(doc, c_start_node, c_element, &c_ns_cache, c_del_ns_list)
+
+ c_node = c_element.properties
+ while c_node is not NULL:
+ if c_node.ns is not NULL:
+ _fixCNs(doc, c_start_node, c_node, &c_ns_cache, c_del_ns_list)
+ c_node = c_node.next
+
+ tree.END_FOR_EACH_FROM(c_element)
+
+ # free now unused namespace declarations
+ if c_del_ns_list is not NULL:
+ tree.xmlFreeNsList(c_del_ns_list)
+
+ # cleanup
+ if c_ns_cache.ns_map is not NULL:
+ python.lxml_free(c_ns_cache.ns_map)
+
+ # 3) fix the names in the tree if we moved it from a different thread
+ if doc._c_doc.dict is not c_source_doc.dict:
+ fixThreadDictNames(c_start_node, c_source_doc.dict, doc._c_doc.dict)
+
+ # 4) fix _Document references
+ # (and potentially deallocate the source document)
+ if proxy_count > 0:
+ if proxy_count == 1 and c_start_node._private is not NULL:
+ proxy = getProxy(c_start_node)
+ if proxy is not None:
+ if proxy._doc is not doc:
+ proxy._doc = doc
+ else:
+ fixElementDocument(c_start_node, doc, proxy_count)
+ else:
+ fixElementDocument(c_start_node, doc, proxy_count)
+
+ return 0
+
+
+cdef void _setTreeDoc(xmlNode* c_node, xmlDoc* c_doc) noexcept:
+ """Adaptation of 'xmlSetTreeDoc()' that deep-fixes the document links iteratively.
+ It avoids https://gitlab.gnome.org/GNOME/libxml2/issues/42
+ """
+ tree.BEGIN_FOR_EACH_FROM(c_node, c_node, 1)
+ if c_node.type == tree.XML_ELEMENT_NODE:
+ c_attr = c_node.properties
+ while c_attr:
+ if c_attr.atype == tree.XML_ATTRIBUTE_ID:
+ tree.xmlRemoveID(c_node.doc, c_attr)
+ c_attr.doc = c_doc
+ _fixDocChildren(c_attr.children, c_doc)
+ c_attr = c_attr.next
+ # Set doc link for all nodes, not only elements.
+ c_node.doc = c_doc
+ tree.END_FOR_EACH_FROM(c_node)
+
+
+cdef inline void _fixDocChildren(xmlNode* c_child, xmlDoc* c_doc) noexcept:
+ while c_child:
+ c_child.doc = c_doc
+ if c_child.children:
+ _fixDocChildren(c_child.children, c_doc)
+ c_child = c_child.next
+
+
+cdef int _fixCNs(_Document doc, xmlNode* c_start_node, xmlNode* c_node,
+ _nscache* c_ns_cache, xmlNs* c_del_ns_list) except -1:
+ cdef xmlNs* c_ns = NULL
+ cdef bint is_prefixed_attr = (c_node.type == tree.XML_ATTRIBUTE_NODE and c_node.ns.prefix)
+
+ for ns_map in c_ns_cache.ns_map[:c_ns_cache.last]:
+ if c_node.ns is ns_map.old:
+ if is_prefixed_attr and not ns_map.new.prefix:
+ # avoid dropping prefix from attributes
+ continue
+ c_ns = ns_map.new
+ break
+
+ if c_ns:
+ c_node.ns = c_ns
+ else:
+ # not in cache or not acceptable
+ # => find a replacement from this document
+ try:
+ c_ns = doc._findOrBuildNodeNs(
+ c_start_node, c_node.ns.href, c_node.ns.prefix,
+ c_node.type == tree.XML_ATTRIBUTE_NODE)
+ c_node.ns = c_ns
+ _appendToNsCache(c_ns_cache, c_node.ns, c_ns)
+ except:
+ _cleanUpFromNamespaceAdaptation(c_start_node, c_ns_cache, c_del_ns_list)
+ raise
+ return 0
+
+
+cdef int fixElementDocument(xmlNode* c_element, _Document doc,
+ size_t proxy_count) except -1:
+ cdef xmlNode* c_node = c_element
+ cdef _Element proxy = None # init-to-None required due to fake-loop below
+ tree.BEGIN_FOR_EACH_FROM(c_element, c_node, 1)
+ if c_node._private is not NULL:
+ proxy = getProxy(c_node)
+ if proxy is not None:
+ if proxy._doc is not doc:
+ proxy._doc = doc
+ proxy_count -= 1
+ if proxy_count == 0:
+ return 0
+ tree.END_FOR_EACH_FROM(c_node)
+
+
+cdef void fixThreadDictNames(xmlNode* c_element,
+ tree.xmlDict* c_src_dict,
+ tree.xmlDict* c_dict) noexcept nogil:
+ # re-assign the names of tags and attributes
+ #
+ # this should only be called when the element is based on a
+ # different libxml2 tag name dictionary
+ if c_element.type == tree.XML_DOCUMENT_NODE or \
+ c_element.type == tree.XML_HTML_DOCUMENT_NODE:
+ # may define "xml" namespace
+ fixThreadDictNsForNode(c_element, c_src_dict, c_dict)
+ if c_element.doc.extSubset:
+ fixThreadDictNamesForDtd(c_element.doc.extSubset, c_src_dict, c_dict)
+ if c_element.doc.intSubset:
+ fixThreadDictNamesForDtd(c_element.doc.intSubset, c_src_dict, c_dict)
+ c_element = c_element.children
+ while c_element is not NULL:
+ fixThreadDictNamesForNode(c_element, c_src_dict, c_dict)
+ c_element = c_element.next
+ elif tree._isElementOrXInclude(c_element):
+ fixThreadDictNamesForNode(c_element, c_src_dict, c_dict)
+
+
+cdef inline void _fixThreadDictPtr(const_xmlChar** c_ptr,
+ tree.xmlDict* c_src_dict,
+ tree.xmlDict* c_dict) noexcept nogil:
+ c_str = c_ptr[0]
+ if c_str and c_src_dict and tree.xmlDictOwns(c_src_dict, c_str):
+ # return value can be NULL on memory error, but we don't handle that here
+ c_str = tree.xmlDictLookup(c_dict, c_str, -1)
+ if c_str:
+ c_ptr[0] = c_str
+
+
+cdef void fixThreadDictNamesForNode(xmlNode* c_element,
+ tree.xmlDict* c_src_dict,
+ tree.xmlDict* c_dict) noexcept nogil:
+ cdef xmlNode* c_node = c_element
+ tree.BEGIN_FOR_EACH_FROM(c_element, c_node, 1)
+ if c_node.type in (tree.XML_ELEMENT_NODE, tree.XML_XINCLUDE_START):
+ fixThreadDictNamesForAttributes(
+ c_node.properties, c_src_dict, c_dict)
+ fixThreadDictNsForNode(c_node, c_src_dict, c_dict)
+ _fixThreadDictPtr(&c_node.name, c_src_dict, c_dict)
+ elif c_node.type == tree.XML_TEXT_NODE:
+ # libxml2's SAX2 parser interns some indentation space
+ fixThreadDictContentForNode(c_node, c_src_dict, c_dict)
+ elif c_node.type == tree.XML_COMMENT_NODE:
+ pass # don't touch c_node.name
+ else:
+ _fixThreadDictPtr(&c_node.name, c_src_dict, c_dict)
+ tree.END_FOR_EACH_FROM(c_node)
+
+
+cdef inline void fixThreadDictNamesForAttributes(tree.xmlAttr* c_attr,
+ tree.xmlDict* c_src_dict,
+ tree.xmlDict* c_dict) noexcept nogil:
+ cdef xmlNode* c_child
+ cdef xmlNode* c_node = c_attr
+ while c_node is not NULL:
+ if c_node.type not in (tree.XML_TEXT_NODE, tree.XML_COMMENT_NODE):
+ _fixThreadDictPtr(&c_node.name, c_src_dict, c_dict)
+ # libxml2 keeps some (!) attribute values in the dict
+ c_child = c_node.children
+ while c_child is not NULL:
+ fixThreadDictContentForNode(c_child, c_src_dict, c_dict)
+ c_child = c_child.next
+ c_node = c_node.next
+
+
+cdef inline void fixThreadDictContentForNode(xmlNode* c_node,
+ tree.xmlDict* c_src_dict,
+ tree.xmlDict* c_dict) noexcept nogil:
+ if c_node.content is not NULL and \
+ c_node.content is not &c_node.properties:
+ if tree.xmlDictOwns(c_src_dict, c_node.content):
+ # result can be NULL on memory error, but we don't handle that here
+ c_node.content = tree.xmlDictLookup(c_dict, c_node.content, -1)
+
+
+cdef inline void fixThreadDictNsForNode(xmlNode* c_node,
+ tree.xmlDict* c_src_dict,
+ tree.xmlDict* c_dict) noexcept nogil:
+ cdef xmlNs* c_ns = c_node.nsDef
+ while c_ns is not NULL:
+ _fixThreadDictPtr(&c_ns.href, c_src_dict, c_dict)
+ _fixThreadDictPtr(&c_ns.prefix, c_src_dict, c_dict)
+ c_ns = c_ns.next
+
+
+cdef void fixThreadDictNamesForDtd(tree.xmlDtd* c_dtd,
+ tree.xmlDict* c_src_dict,
+ tree.xmlDict* c_dict) noexcept nogil:
+ cdef xmlNode* c_node
+ cdef tree.xmlElement* c_element
+ cdef tree.xmlAttribute* c_attribute
+ cdef tree.xmlEntity* c_entity
+
+ c_node = c_dtd.children
+ while c_node:
+ if c_node.type == tree.XML_ELEMENT_DECL:
+ c_element = c_node
+ if c_element.content:
+ _fixThreadDictPtr(&c_element.content.name, c_src_dict, c_dict)
+ _fixThreadDictPtr(&c_element.content.prefix, c_src_dict, c_dict)
+ c_attribute = c_element.attributes
+ while c_attribute:
+ if tree.LIBXML_VERSION < 21500:
+ # libxml2 2.15 no longer stores default values in the dict.
+ # See https://gitlab.gnome.org/GNOME/libxml2/-/commit/24628f25
+ _fixThreadDictPtr(&c_attribute.defaultValue, c_src_dict, c_dict)
+ _fixThreadDictPtr(&c_attribute.name, c_src_dict, c_dict)
+ _fixThreadDictPtr(&c_attribute.prefix, c_src_dict, c_dict)
+ _fixThreadDictPtr(&c_attribute.elem, c_src_dict, c_dict)
+ c_attribute = c_attribute.nexth
+ elif c_node.type == tree.XML_ENTITY_DECL:
+ c_entity = c_node
+ _fixThreadDictPtr(&c_entity.name, c_src_dict, c_dict)
+ _fixThreadDictPtr(&c_entity.ExternalID, c_src_dict, c_dict)
+ _fixThreadDictPtr(&c_entity.SystemID, c_src_dict, c_dict)
+ _fixThreadDictPtr(&c_entity.content, c_src_dict, c_dict)
+ c_node = c_node.next
+
+
+################################################################################
+# adopt an xmlDoc from an external libxml2 document source
+
+cdef _Document _adoptForeignDoc(xmlDoc* c_doc, _BaseParser parser=None, bint is_owned=True):
+ """Convert and wrap an externally produced xmlDoc for use in lxml.
+ Assures that all '_private' pointers are NULL to prevent accidental
+ dereference into lxml proxy objects.
+ """
+ if c_doc is NULL:
+ raise ValueError("Illegal document provided: NULL")
+ if c_doc.type not in (tree.XML_DOCUMENT_NODE, tree.XML_HTML_DOCUMENT_NODE):
+ doc_type = c_doc.type
+ if is_owned:
+ tree.xmlFreeDoc(c_doc)
+ raise ValueError(f"Illegal document provided: expected XML or HTML, found {doc_type}")
+
+ cdef xmlNode* c_node = c_doc
+
+ if is_owned:
+ tree.BEGIN_FOR_EACH_FROM(c_doc, c_node, 1)
+ c_node._private = NULL
+ tree.END_FOR_EACH_FROM(c_node)
+ else:
+ # create a fresh copy that lxml owns
+ c_doc = tree.xmlCopyDoc(c_doc, 1)
+ if c_doc is NULL:
+ raise MemoryError()
+
+ return _documentFactory(c_doc, parser)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/public-api.pxi b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/public-api.pxi
new file mode 100644
index 0000000000000000000000000000000000000000..fb8b2a2ced120b69c311270adba08924d65980a6
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/public-api.pxi
@@ -0,0 +1,178 @@
+# Public C API for lxml.etree
+
+cdef public api _Element deepcopyNodeToDocument(_Document doc, xmlNode* c_root):
+ "Recursively copy the element into the document. doc is not modified."
+ cdef xmlNode* c_node
+ c_node = _copyNodeToDoc(c_root, doc._c_doc)
+ return _elementFactory(doc, c_node)
+
+cdef public api _ElementTree elementTreeFactory(_Element context_node):
+ _assertValidNode(context_node)
+ return newElementTree(context_node, _ElementTree)
+
+cdef public api _ElementTree newElementTree(_Element context_node,
+ object subclass):
+ if context_node is NULL or context_node is None:
+ raise TypeError
+ _assertValidNode(context_node)
+ return _newElementTree(context_node._doc, context_node, subclass)
+
+cdef public api _ElementTree adoptExternalDocument(xmlDoc* c_doc, parser, bint is_owned):
+ if c_doc is NULL:
+ raise TypeError
+ doc = _adoptForeignDoc(c_doc, parser, is_owned)
+ return _elementTreeFactory(doc, None)
+
+cdef public api _Element elementFactory(_Document doc, xmlNode* c_node):
+ if c_node is NULL or doc is None:
+ raise TypeError
+ return _elementFactory(doc, c_node)
+
+cdef public api _Element makeElement(tag, _Document doc, parser,
+ text, tail, attrib, nsmap):
+ return _makeElement(tag, NULL, doc, parser, text, tail, attrib, nsmap, None)
+
+cdef public api _Element makeSubElement(_Element parent, tag, text, tail,
+ attrib, nsmap):
+ _assertValidNode(parent)
+ return _makeSubElement(parent, tag, text, tail, attrib, nsmap, None)
+
+cdef public api void setElementClassLookupFunction(
+ _element_class_lookup_function function, state):
+ _setElementClassLookupFunction(function, state)
+
+cdef public api object lookupDefaultElementClass(state, doc, xmlNode* c_node):
+ return _lookupDefaultElementClass(state, doc, c_node)
+
+cdef public api object lookupNamespaceElementClass(state, doc, xmlNode* c_node):
+ return _find_nselement_class(state, doc, c_node)
+
+cdef public api object callLookupFallback(FallbackElementClassLookup lookup,
+ _Document doc, xmlNode* c_node):
+ return _callLookupFallback(lookup, doc, c_node)
+
+cdef public api int tagMatches(xmlNode* c_node, const_xmlChar* c_href, const_xmlChar* c_name):
+ if c_node is NULL:
+ return -1
+ return _tagMatches(c_node, c_href, c_name)
+
+cdef public api _Document documentOrRaise(object input):
+ return _documentOrRaise(input)
+
+cdef public api _Element rootNodeOrRaise(object input):
+ return _rootNodeOrRaise(input)
+
+cdef public api bint hasText(xmlNode* c_node):
+ return _hasText(c_node)
+
+cdef public api bint hasTail(xmlNode* c_node):
+ return _hasTail(c_node)
+
+cdef public api unicode textOf(xmlNode* c_node):
+ if c_node is NULL:
+ return None
+ return _collectText(c_node.children)
+
+cdef public api unicode tailOf(xmlNode* c_node):
+ if c_node is NULL:
+ return None
+ return _collectText(c_node.next)
+
+cdef public api int setNodeText(xmlNode* c_node, text) except -1:
+ if c_node is NULL:
+ raise ValueError
+ return _setNodeText(c_node, text)
+
+cdef public api int setTailText(xmlNode* c_node, text) except -1:
+ if c_node is NULL:
+ raise ValueError
+ return _setTailText(c_node, text)
+
+cdef public api unicode attributeValue(xmlNode* c_element, xmlAttr* c_attrib_node):
+ return _attributeValue(c_element, c_attrib_node)
+
+cdef public api unicode attributeValueFromNsName(xmlNode* c_element,
+ const_xmlChar* ns, const_xmlChar* name):
+ return _attributeValueFromNsName(c_element, ns, name)
+
+cdef public api object getAttributeValue(_Element element, key, default):
+ _assertValidNode(element)
+ return _getAttributeValue(element, key, default)
+
+cdef public api object iterattributes(_Element element, int keysvalues):
+ _assertValidNode(element)
+ return _attributeIteratorFactory(element, keysvalues)
+
+cdef public api list collectAttributes(xmlNode* c_element, int keysvalues):
+ return _collectAttributes(c_element, keysvalues)
+
+cdef public api int setAttributeValue(_Element element, key, value) except -1:
+ _assertValidNode(element)
+ return _setAttributeValue(element, key, value)
+
+cdef public api int delAttribute(_Element element, key) except -1:
+ _assertValidNode(element)
+ return _delAttribute(element, key)
+
+cdef public api int delAttributeFromNsName(tree.xmlNode* c_element,
+ const_xmlChar* c_href, const_xmlChar* c_name):
+ return _delAttributeFromNsName(c_element, c_href, c_name)
+
+cdef public api bint hasChild(xmlNode* c_node):
+ return _hasChild(c_node)
+
+cdef public api xmlNode* findChild(xmlNode* c_node, Py_ssize_t index):
+ return _findChild(c_node, index)
+
+cdef public api xmlNode* findChildForwards(xmlNode* c_node, Py_ssize_t index):
+ return _findChildForwards(c_node, index)
+
+cdef public api xmlNode* findChildBackwards(xmlNode* c_node, Py_ssize_t index):
+ return _findChildBackwards(c_node, index)
+
+cdef public api xmlNode* nextElement(xmlNode* c_node):
+ return _nextElement(c_node)
+
+cdef public api xmlNode* previousElement(xmlNode* c_node):
+ return _previousElement(c_node)
+
+cdef public api void appendChild(_Element parent, _Element child):
+ # deprecated, use appendChildToElement() instead!
+ _appendChild(parent, child)
+
+cdef public api int appendChildToElement(_Element parent, _Element child) except -1:
+ return _appendChild(parent, child)
+
+cdef public api unicode pyunicode(const_xmlChar* s):
+ if s is NULL:
+ raise TypeError
+ return funicode(s)
+
+cdef public api bytes utf8(object s):
+ return _utf8(s)
+
+cdef public api tuple getNsTag(object tag):
+ return _getNsTag(tag)
+
+cdef public api tuple getNsTagWithEmptyNs(object tag):
+ return _getNsTagWithEmptyNs(tag)
+
+cdef public api unicode namespacedName(xmlNode* c_node):
+ return _namespacedName(c_node)
+
+cdef public api unicode namespacedNameFromNsName(const_xmlChar* href, const_xmlChar* name):
+ return _namespacedNameFromNsName(href, name)
+
+cdef public api void iteratorStoreNext(_ElementIterator iterator, _Element node):
+ # deprecated!
+ iterator._storeNext(node)
+
+cdef public api void initTagMatch(_ElementTagMatcher matcher, tag):
+ # deprecated!
+ matcher._initTagMatch(tag)
+
+cdef public api tree.xmlNs* findOrBuildNodeNsPrefix(
+ _Document doc, xmlNode* c_node, const_xmlChar* href, const_xmlChar* prefix) except NULL:
+ if doc is None:
+ raise TypeError
+ return doc._findOrBuildNodeNs(c_node, href, prefix, 0)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/pyclasslookup.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/pyclasslookup.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e1496dfb762108154a0c6c321a5e8fcf73de909
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/pyclasslookup.py
@@ -0,0 +1,3 @@
+# dummy module for backwards compatibility
+
+from lxml.etree import PythonElementClassLookup
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/relaxng.pxi b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/relaxng.pxi
new file mode 100644
index 0000000000000000000000000000000000000000..35f875891f7e59a785518b8b70bd19ef3f0f6099
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/relaxng.pxi
@@ -0,0 +1,165 @@
+# support for RelaxNG validation
+from lxml.includes cimport relaxng
+
+cdef object _rnc2rng
+try:
+ import rnc2rng as _rnc2rng
+except ImportError:
+ _rnc2rng = None
+
+
+cdef int _require_rnc2rng() except -1:
+ if _rnc2rng is None:
+ raise RelaxNGParseError(
+ 'compact syntax not supported (please install rnc2rng)')
+ return 0
+
+
+cdef class RelaxNGError(LxmlError):
+ """Base class for RelaxNG errors.
+ """
+
+cdef class RelaxNGParseError(RelaxNGError):
+ """Error while parsing an XML document as RelaxNG.
+ """
+
+cdef class RelaxNGValidateError(RelaxNGError):
+ """Error while validating an XML document with a RelaxNG schema.
+ """
+
+
+################################################################################
+# RelaxNG
+
+cdef class RelaxNG(_Validator):
+ """RelaxNG(self, etree=None, file=None)
+ Turn a document into a Relax NG validator.
+
+ Either pass a schema as Element or ElementTree, or pass a file or
+ filename through the ``file`` keyword argument.
+ """
+ cdef relaxng.xmlRelaxNG* _c_schema
+ def __cinit__(self):
+ self._c_schema = NULL
+
+ def __init__(self, etree=None, *, file=None):
+ cdef _Document doc
+ cdef _Element root_node
+ cdef xmlDoc* fake_c_doc = NULL
+ cdef relaxng.xmlRelaxNGParserCtxt* parser_ctxt = NULL
+ _Validator.__init__(self)
+ if etree is not None:
+ doc = _documentOrRaise(etree)
+ root_node = _rootNodeOrRaise(etree)
+ fake_c_doc = _fakeRootDoc(doc._c_doc, root_node._c_node)
+ parser_ctxt = relaxng.xmlRelaxNGNewDocParserCtxt(fake_c_doc)
+ elif file is not None:
+ if _isString(file):
+ if file[-4:].lower() == '.rnc':
+ _require_rnc2rng()
+ rng_data_utf8 = _utf8(_rnc2rng.dumps(_rnc2rng.load(file)))
+ doc = _parseMemoryDocument(rng_data_utf8, parser=None, url=file)
+ parser_ctxt = relaxng.xmlRelaxNGNewDocParserCtxt(doc._c_doc)
+ else:
+ doc = None
+ filename = _encodeFilename(file)
+ with self._error_log:
+ orig_loader = _register_document_loader()
+ parser_ctxt = relaxng.xmlRelaxNGNewParserCtxt(_cstr(filename))
+ _reset_document_loader(orig_loader)
+ elif (_getFilenameForFile(file) or '')[-4:].lower() == '.rnc':
+ _require_rnc2rng()
+ rng_data_utf8 = _utf8(_rnc2rng.dumps(_rnc2rng.load(file)))
+ doc = _parseMemoryDocument(
+ rng_data_utf8, parser=None, url=_getFilenameForFile(file))
+ parser_ctxt = relaxng.xmlRelaxNGNewDocParserCtxt(doc._c_doc)
+ else:
+ doc = _parseDocument(file, parser=None, base_url=None)
+ parser_ctxt = relaxng.xmlRelaxNGNewDocParserCtxt(doc._c_doc)
+ else:
+ raise RelaxNGParseError, "No tree or file given"
+
+ if parser_ctxt is NULL:
+ if fake_c_doc is not NULL:
+ _destroyFakeDoc(doc._c_doc, fake_c_doc)
+ raise RelaxNGParseError(
+ self._error_log._buildExceptionMessage(
+ "Document is not parsable as Relax NG"),
+ self._error_log)
+
+ # Need a cast here because older libxml2 releases do not use 'const' in the functype.
+ relaxng.xmlRelaxNGSetParserStructuredErrors(
+ parser_ctxt, _receiveError, self._error_log)
+ _connectGenericErrorLog(self._error_log, xmlerror.XML_FROM_RELAXNGP)
+ self._c_schema = relaxng.xmlRelaxNGParse(parser_ctxt)
+ _connectGenericErrorLog(None)
+
+ relaxng.xmlRelaxNGFreeParserCtxt(parser_ctxt)
+ if self._c_schema is NULL:
+ if fake_c_doc is not NULL:
+ _destroyFakeDoc(doc._c_doc, fake_c_doc)
+ raise RelaxNGParseError(
+ self._error_log._buildExceptionMessage(
+ "Document is not valid Relax NG"),
+ self._error_log)
+ if fake_c_doc is not NULL:
+ _destroyFakeDoc(doc._c_doc, fake_c_doc)
+
+ def __dealloc__(self):
+ relaxng.xmlRelaxNGFree(self._c_schema)
+
+ def __call__(self, etree):
+ """__call__(self, etree)
+
+ Validate doc using Relax NG.
+
+ Returns true if document is valid, false if not."""
+ cdef _Document doc
+ cdef _Element root_node
+ cdef xmlDoc* c_doc
+ cdef relaxng.xmlRelaxNGValidCtxt* valid_ctxt
+ cdef int ret
+
+ assert self._c_schema is not NULL, "RelaxNG instance not initialised"
+ doc = _documentOrRaise(etree)
+ root_node = _rootNodeOrRaise(etree)
+
+ valid_ctxt = relaxng.xmlRelaxNGNewValidCtxt(self._c_schema)
+ if valid_ctxt is NULL:
+ raise MemoryError()
+
+ try:
+ self._error_log.clear()
+ # Need a cast here because older libxml2 releases do not use 'const' in the functype.
+ relaxng.xmlRelaxNGSetValidStructuredErrors(
+ valid_ctxt, _receiveError, self._error_log)
+ _connectGenericErrorLog(self._error_log, xmlerror.XML_FROM_RELAXNGV)
+ c_doc = _fakeRootDoc(doc._c_doc, root_node._c_node)
+ with nogil:
+ ret = relaxng.xmlRelaxNGValidateDoc(valid_ctxt, c_doc)
+ _destroyFakeDoc(doc._c_doc, c_doc)
+ finally:
+ _connectGenericErrorLog(None)
+ relaxng.xmlRelaxNGFreeValidCtxt(valid_ctxt)
+
+ if ret == -1:
+ raise RelaxNGValidateError(
+ "Internal error in Relax NG validation",
+ self._error_log)
+ if ret == 0:
+ return True
+ else:
+ return False
+
+ @classmethod
+ def from_rnc_string(cls, src, base_url=None):
+ """Parse a RelaxNG schema in compact syntax from a text string
+
+ Requires the rnc2rng package to be installed.
+
+ Passing the source URL or file path of the source as 'base_url'
+ will enable resolving resource references relative to the source.
+ """
+ _require_rnc2rng()
+ rng_str = utf8(_rnc2rng.dumps(_rnc2rng.loads(src)))
+ return cls(_parseMemoryDocument(rng_str, parser=None, url=base_url))
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/sax.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/sax.py
new file mode 100644
index 0000000000000000000000000000000000000000..db77f6f29705e0776471c013abd0e1f97c18a457
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/sax.py
@@ -0,0 +1,285 @@
+"""
+SAX-based adapter to copy trees from/to the Python standard library.
+
+Use the `ElementTreeContentHandler` class to build an ElementTree from
+SAX events.
+
+Use the `ElementTreeProducer` class or the `saxify()` function to fire
+the SAX events of an ElementTree against a SAX ContentHandler.
+
+See https://lxml.de/sax.html
+"""
+
+
+from xml.sax.handler import ContentHandler
+from lxml import etree
+from lxml.etree import ElementTree, SubElement
+from lxml.etree import Comment, ProcessingInstruction
+
+try:
+ from types import GenericAlias as _GenericAlias
+except ImportError:
+ # Python 3.8 - we only need this as return value from "__class_getitem__"
+ def _GenericAlias(cls, item):
+ return f"{cls.__name__}[{item.__name__}]"
+
+
+class SaxError(etree.LxmlError):
+ """General SAX error.
+ """
+
+
+def _getNsTag(tag):
+ if tag[0] == '{' and '}' in tag:
+ return tuple(tag[1:].split('}', 1))
+ else:
+ return None, tag
+
+
+class ElementTreeContentHandler(ContentHandler):
+ """Build an lxml ElementTree from SAX events.
+ """
+ def __init__(self, makeelement=None):
+ ContentHandler.__init__(self)
+ self._root = None
+ self._root_siblings = []
+ self._element_stack = []
+ self._default_ns = None
+ self._ns_mapping = { None : [None] }
+ self._new_mappings = {}
+ if makeelement is None:
+ makeelement = etree.Element
+ self._makeelement = makeelement
+
+ def _get_etree(self):
+ "Contains the generated ElementTree after parsing is finished."
+ return ElementTree(self._root)
+
+ etree = property(_get_etree, doc=_get_etree.__doc__)
+
+ def setDocumentLocator(self, locator):
+ pass
+
+ def startDocument(self):
+ pass
+
+ def endDocument(self):
+ pass
+
+ def startPrefixMapping(self, prefix, uri):
+ self._new_mappings[prefix] = uri
+ try:
+ self._ns_mapping[prefix].append(uri)
+ except KeyError:
+ self._ns_mapping[prefix] = [uri]
+ if prefix is None:
+ self._default_ns = uri
+
+ def endPrefixMapping(self, prefix):
+ ns_uri_list = self._ns_mapping[prefix]
+ ns_uri_list.pop()
+ if prefix is None:
+ self._default_ns = ns_uri_list[-1]
+
+ def _buildTag(self, ns_name_tuple):
+ ns_uri, local_name = ns_name_tuple
+ if ns_uri:
+ el_tag = "{%s}%s" % ns_name_tuple
+ elif self._default_ns:
+ el_tag = "{%s}%s" % (self._default_ns, local_name)
+ else:
+ el_tag = local_name
+ return el_tag
+
+ def startElementNS(self, ns_name, qname, attributes=None):
+ el_name = self._buildTag(ns_name)
+ if attributes:
+ attrs = {}
+ try:
+ iter_attributes = attributes.iteritems()
+ except AttributeError:
+ iter_attributes = attributes.items()
+
+ for name_tuple, value in iter_attributes:
+ if name_tuple[0]:
+ attr_name = "{%s}%s" % name_tuple
+ else:
+ attr_name = name_tuple[1]
+ attrs[attr_name] = value
+ else:
+ attrs = None
+
+ element_stack = self._element_stack
+ if self._root is None:
+ element = self._root = \
+ self._makeelement(el_name, attrs, self._new_mappings)
+ if self._root_siblings and hasattr(element, 'addprevious'):
+ for sibling in self._root_siblings:
+ element.addprevious(sibling)
+ del self._root_siblings[:]
+ else:
+ element = SubElement(element_stack[-1], el_name,
+ attrs, self._new_mappings)
+ element_stack.append(element)
+
+ self._new_mappings.clear()
+
+ def processingInstruction(self, target, data):
+ pi = ProcessingInstruction(target, data)
+ if self._root is None:
+ self._root_siblings.append(pi)
+ else:
+ self._element_stack[-1].append(pi)
+
+ def endElementNS(self, ns_name, qname):
+ element = self._element_stack.pop()
+ el_tag = self._buildTag(ns_name)
+ if el_tag != element.tag:
+ raise SaxError("Unexpected element closed: " + el_tag)
+
+ def startElement(self, name, attributes=None):
+ if attributes:
+ attributes = {(None, k): v for k, v in attributes.items()}
+ self.startElementNS((None, name), name, attributes)
+
+ def endElement(self, name):
+ self.endElementNS((None, name), name)
+
+ def characters(self, data):
+ last_element = self._element_stack[-1]
+ try:
+ # if there already is a child element, we must append to its tail
+ last_element = last_element[-1]
+ except IndexError:
+ # otherwise: append to the text
+ last_element.text = (last_element.text or '') + data
+ else:
+ last_element.tail = (last_element.tail or '') + data
+
+ ignorableWhitespace = characters
+
+ # Allow subscripting sax.ElementTreeContentHandler in type annotions (PEP 560)
+ def __class_getitem__(cls, item):
+ return _GenericAlias(cls, item)
+
+
+class ElementTreeProducer:
+ """Produces SAX events for an element and children.
+ """
+ def __init__(self, element_or_tree, content_handler):
+ try:
+ element = element_or_tree.getroot()
+ except AttributeError:
+ element = element_or_tree
+ self._element = element
+ self._content_handler = content_handler
+ from xml.sax.xmlreader import AttributesNSImpl as attr_class
+ self._attr_class = attr_class
+ self._empty_attributes = attr_class({}, {})
+
+ def saxify(self):
+ self._content_handler.startDocument()
+
+ element = self._element
+ if hasattr(element, 'getprevious'):
+ siblings = []
+ sibling = element.getprevious()
+ while getattr(sibling, 'tag', None) is ProcessingInstruction:
+ siblings.append(sibling)
+ sibling = sibling.getprevious()
+ for sibling in siblings[::-1]:
+ self._recursive_saxify(sibling, {})
+
+ self._recursive_saxify(element, {})
+
+ if hasattr(element, 'getnext'):
+ sibling = element.getnext()
+ while getattr(sibling, 'tag', None) is ProcessingInstruction:
+ self._recursive_saxify(sibling, {})
+ sibling = sibling.getnext()
+
+ self._content_handler.endDocument()
+
+ def _recursive_saxify(self, element, parent_nsmap):
+ content_handler = self._content_handler
+ tag = element.tag
+ if tag is Comment or tag is ProcessingInstruction:
+ if tag is ProcessingInstruction:
+ content_handler.processingInstruction(
+ element.target, element.text)
+ tail = element.tail
+ if tail:
+ content_handler.characters(tail)
+ return
+
+ element_nsmap = element.nsmap
+ new_prefixes = []
+ if element_nsmap != parent_nsmap:
+ # There have been updates to the namespace
+ for prefix, ns_uri in element_nsmap.items():
+ if parent_nsmap.get(prefix) != ns_uri:
+ new_prefixes.append( (prefix, ns_uri) )
+
+ attribs = element.items()
+ if attribs:
+ attr_values = {}
+ attr_qnames = {}
+ for attr_ns_name, value in attribs:
+ attr_ns_tuple = _getNsTag(attr_ns_name)
+ attr_values[attr_ns_tuple] = value
+ attr_qnames[attr_ns_tuple] = self._build_qname(
+ attr_ns_tuple[0], attr_ns_tuple[1], element_nsmap,
+ preferred_prefix=None, is_attribute=True)
+ sax_attributes = self._attr_class(attr_values, attr_qnames)
+ else:
+ sax_attributes = self._empty_attributes
+
+ ns_uri, local_name = _getNsTag(tag)
+ qname = self._build_qname(
+ ns_uri, local_name, element_nsmap, element.prefix, is_attribute=False)
+
+ for prefix, uri in new_prefixes:
+ content_handler.startPrefixMapping(prefix, uri)
+ content_handler.startElementNS(
+ (ns_uri, local_name), qname, sax_attributes)
+ text = element.text
+ if text:
+ content_handler.characters(text)
+ for child in element:
+ self._recursive_saxify(child, element_nsmap)
+ content_handler.endElementNS((ns_uri, local_name), qname)
+ for prefix, uri in new_prefixes:
+ content_handler.endPrefixMapping(prefix)
+ tail = element.tail
+ if tail:
+ content_handler.characters(tail)
+
+ def _build_qname(self, ns_uri, local_name, nsmap, preferred_prefix, is_attribute):
+ if ns_uri is None:
+ return local_name
+
+ if not is_attribute and nsmap.get(preferred_prefix) == ns_uri:
+ prefix = preferred_prefix
+ else:
+ # Pick the first matching prefix, in alphabetical order.
+ candidates = [
+ pfx for (pfx, uri) in nsmap.items()
+ if pfx is not None and uri == ns_uri
+ ]
+ prefix = (
+ candidates[0] if len(candidates) == 1
+ else min(candidates) if candidates
+ else None
+ )
+
+ if prefix is None:
+ # Default namespace
+ return local_name
+ return prefix + ':' + local_name
+
+
+def saxify(element_or_tree, content_handler):
+ """One-shot helper to generate SAX events from an XML tree and fire
+ them against a SAX ContentHandler.
+ """
+ return ElementTreeProducer(element_or_tree, content_handler).saxify()
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/schematron.pxi b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/schematron.pxi
new file mode 100644
index 0000000000000000000000000000000000000000..650e34b2b4c562517479b56a8f6f45b7111efafd
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/schematron.pxi
@@ -0,0 +1,173 @@
+# support for Schematron validation
+from lxml.includes cimport schematron
+
+
+cdef class SchematronError(LxmlError):
+ """Base class of all Schematron errors.
+ """
+
+cdef class SchematronParseError(SchematronError):
+ """Error while parsing an XML document as Schematron schema.
+ """
+
+cdef class SchematronValidateError(SchematronError):
+ """Error while validating an XML document with a Schematron schema.
+ """
+
+
+################################################################################
+# Schematron
+
+cdef class Schematron(_Validator):
+ """Schematron(self, etree=None, file=None)
+ A Schematron validator.
+
+ Pass a root Element or an ElementTree to turn it into a validator.
+ Alternatively, pass a filename as keyword argument 'file' to parse from
+ the file system.
+
+ Schematron is a less well known, but very powerful schema language. The main
+ idea is to use the capabilities of XPath to put restrictions on the structure
+ and the content of XML documents. Here is a simple example::
+
+ >>> schematron = Schematron(XML('''
+ ...
+ ...
+ ...
+ ... Attribute
+ ... is forbidden
+ ...
+ ...
+ ...
+ ...
+ ... '''))
+
+ >>> xml = XML('''
+ ...
+ ...
+ ...
+ ...
+ ... ''')
+
+ >>> schematron.validate(xml)
+ 0
+
+ >>> xml = XML('''
+ ...
+ ...
+ ...
+ ...
+ ... ''')
+
+ >>> schematron.validate(xml)
+ 1
+
+ Schematron was added to libxml2 in version 2.6.21. Before version 2.6.32,
+ however, Schematron lacked support for error reporting other than to stderr.
+ This version is therefore required to retrieve validation warnings and
+ errors in lxml.
+ """
+ cdef schematron.xmlSchematron* _c_schema
+ cdef xmlDoc* _c_schema_doc
+
+ def __init__(self, etree=None, *, file=None):
+ cdef _Document doc
+ cdef _Element root_node
+ cdef xmlNode* c_node
+ cdef char* c_href
+ cdef schematron.xmlSchematronParserCtxt* parser_ctxt = NULL
+ _Validator.__init__(self)
+ if not config.ENABLE_SCHEMATRON:
+ raise SchematronError, \
+ "lxml.etree was compiled without Schematron support."
+
+ import warnings
+ warnings.warn(
+ "The (non-ISO) Schematron feature is deprecated and will be removed from libxml2 and lxml. "
+ "Use 'lxml.isoschematron' instead.",
+ DeprecationWarning,
+ )
+
+ if etree is not None:
+ doc = _documentOrRaise(etree)
+ root_node = _rootNodeOrRaise(etree)
+ self._c_schema_doc = _copyDocRoot(doc._c_doc, root_node._c_node)
+ parser_ctxt = schematron.xmlSchematronNewDocParserCtxt(self._c_schema_doc)
+ elif file is not None:
+ filename = _getFilenameForFile(file)
+ if filename is None:
+ # XXX assume a string object
+ filename = file
+ filename = _encodeFilename(filename)
+ with self._error_log:
+ orig_loader = _register_document_loader()
+ parser_ctxt = schematron.xmlSchematronNewParserCtxt(_cstr(filename))
+ _reset_document_loader(orig_loader)
+ else:
+ raise SchematronParseError, "No tree or file given"
+
+ if parser_ctxt is NULL:
+ if self._c_schema_doc is not NULL:
+ tree.xmlFreeDoc(self._c_schema_doc)
+ self._c_schema_doc = NULL
+ raise MemoryError()
+
+ try:
+ with self._error_log:
+ orig_loader = _register_document_loader()
+ self._c_schema = schematron.xmlSchematronParse(parser_ctxt)
+ _reset_document_loader(orig_loader)
+ finally:
+ schematron.xmlSchematronFreeParserCtxt(parser_ctxt)
+
+ if self._c_schema is NULL:
+ raise SchematronParseError(
+ "Document is not a valid Schematron schema",
+ self._error_log)
+
+ def __dealloc__(self):
+ schematron.xmlSchematronFree(self._c_schema)
+ if self._c_schema_doc is not NULL:
+ tree.xmlFreeDoc(self._c_schema_doc)
+
+ def __call__(self, etree):
+ """__call__(self, etree)
+
+ Validate doc using Schematron.
+
+ Returns true if document is valid, false if not."""
+ cdef _Document doc
+ cdef _Element root_node
+ cdef xmlDoc* c_doc
+ cdef schematron.xmlSchematronValidCtxt* valid_ctxt
+ cdef int ret
+
+ assert self._c_schema is not NULL, "Schematron instance not initialised"
+ doc = _documentOrRaise(etree)
+ root_node = _rootNodeOrRaise(etree)
+
+ valid_ctxt = schematron.xmlSchematronNewValidCtxt(
+ self._c_schema, schematron.XML_SCHEMATRON_OUT_ERROR)
+ if valid_ctxt is NULL:
+ raise MemoryError()
+
+ try:
+ self._error_log.clear()
+ # Need a cast here because older libxml2 releases do not use 'const' in the functype.
+ schematron.xmlSchematronSetValidStructuredErrors(
+ valid_ctxt, _receiveError, self._error_log)
+ c_doc = _fakeRootDoc(doc._c_doc, root_node._c_node)
+ with nogil:
+ ret = schematron.xmlSchematronValidateDoc(valid_ctxt, c_doc)
+ _destroyFakeDoc(doc._c_doc, c_doc)
+ finally:
+ schematron.xmlSchematronFreeValidCtxt(valid_ctxt)
+
+ if ret == -1:
+ raise SchematronValidateError(
+ "Internal error in Schematron validation",
+ self._error_log)
+ if ret == 0:
+ return True
+ else:
+ return False
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/serializer.pxi b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/serializer.pxi
new file mode 100644
index 0000000000000000000000000000000000000000..5266bdf2bdc71a0bbdbe0c8702dbc43f5684ee28
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/serializer.pxi
@@ -0,0 +1,1849 @@
+# XML serialization and output functions
+
+cdef object GzipFile
+from gzip import GzipFile
+
+
+cdef class SerialisationError(LxmlError):
+ """A libxml2 error that occurred during serialisation.
+ """
+
+
+cdef enum _OutputMethods:
+ OUTPUT_METHOD_XML
+ OUTPUT_METHOD_HTML
+ OUTPUT_METHOD_TEXT
+
+
+cdef int _findOutputMethod(method) except -1:
+ if method is None:
+ return OUTPUT_METHOD_XML
+ method = method.lower()
+ if method == "xml":
+ return OUTPUT_METHOD_XML
+ if method == "html":
+ return OUTPUT_METHOD_HTML
+ if method == "text":
+ return OUTPUT_METHOD_TEXT
+ raise ValueError(f"unknown output method {method!r}")
+
+
+cdef _textToString(xmlNode* c_node, encoding, bint with_tail):
+ cdef bint needs_conversion
+ cdef const_xmlChar* c_text
+ cdef xmlNode* c_text_node
+ cdef tree.xmlBuffer* c_buffer
+ cdef int error_result
+
+ c_buffer = tree.xmlBufferCreate()
+ if c_buffer is NULL:
+ raise MemoryError()
+
+ with nogil:
+ error_result = tree.xmlNodeBufGetContent(c_buffer, c_node)
+ if with_tail:
+ c_text_node = _textNodeOrSkip(c_node.next)
+ while c_text_node is not NULL:
+ tree.xmlBufferWriteChar(c_buffer, c_text_node.content)
+ c_text_node = _textNodeOrSkip(c_text_node.next)
+ c_text = tree.xmlBufferContent(c_buffer)
+
+ if error_result < 0 or c_text is NULL:
+ tree.xmlBufferFree(c_buffer)
+ raise SerialisationError, "Error during serialisation (out of memory?)"
+
+ try:
+ needs_conversion = 0
+ if encoding is unicode:
+ needs_conversion = 1
+ elif encoding is not None:
+ # Python prefers lower case encoding names
+ encoding = encoding.lower()
+ if encoding not in ('utf8', 'utf-8'):
+ if encoding == 'ascii':
+ if isutf8l(c_text, tree.xmlBufferLength(c_buffer)):
+ # will raise a decode error below
+ needs_conversion = 1
+ else:
+ needs_conversion = 1
+
+ if needs_conversion:
+ text = (c_text)[:tree.xmlBufferLength(c_buffer)].decode('utf8')
+ if encoding is not unicode:
+ encoding = _utf8(encoding)
+ text = python.PyUnicode_AsEncodedString(
+ text, encoding, 'strict')
+ else:
+ text = (c_text)[:tree.xmlBufferLength(c_buffer)]
+ finally:
+ tree.xmlBufferFree(c_buffer)
+ return text
+
+
+cdef _tostring(_Element element, encoding, doctype, method,
+ bint write_xml_declaration, bint write_complete_document,
+ bint pretty_print, bint with_tail, int standalone):
+ """Serialize an element to an encoded string representation of its XML
+ tree.
+ """
+ cdef tree.xmlOutputBuffer* c_buffer
+ cdef tree.xmlBuf* c_result_buffer
+ cdef tree.xmlCharEncodingHandler* enchandler
+ cdef const_char* c_enc
+ cdef const_xmlChar* c_version
+ cdef const_xmlChar* c_doctype
+ cdef int c_method
+ cdef int error_result
+ if element is None:
+ return None
+ _assertValidNode(element)
+ c_method = _findOutputMethod(method)
+ if c_method == OUTPUT_METHOD_TEXT:
+ return _textToString(element._c_node, encoding, with_tail)
+ if encoding is None or encoding is unicode:
+ c_enc = NULL
+ else:
+ encoding = _utf8(encoding)
+ c_enc = _cstr(encoding)
+ if doctype is None:
+ c_doctype = NULL
+ else:
+ doctype = _utf8(doctype)
+ c_doctype = _xcstr(doctype)
+ # it is necessary to *and* find the encoding handler *and* use
+ # encoding during output
+ enchandler = tree.xmlFindCharEncodingHandler(c_enc)
+ if enchandler is NULL and c_enc is not NULL:
+ if encoding is not None:
+ encoding = encoding.decode('UTF-8')
+ raise LookupError, f"unknown encoding: '{encoding}'"
+ c_buffer = tree.xmlAllocOutputBuffer(enchandler)
+ if c_buffer is NULL:
+ tree.xmlCharEncCloseFunc(enchandler)
+ raise MemoryError()
+
+ with nogil:
+ _writeNodeToBuffer(c_buffer, element._c_node, c_enc, c_doctype, c_method,
+ write_xml_declaration, write_complete_document,
+ pretty_print, with_tail, standalone)
+ tree.xmlOutputBufferFlush(c_buffer)
+ if c_buffer.conv is not NULL:
+ c_result_buffer = c_buffer.conv
+ else:
+ c_result_buffer = c_buffer.buffer
+
+ error_result = c_buffer.error
+ if error_result != xmlerror.XML_ERR_OK:
+ tree.xmlOutputBufferClose(c_buffer)
+ _raiseSerialisationError(error_result)
+
+ try:
+ if encoding is unicode:
+ result = (tree.xmlBufContent(
+ c_result_buffer))[:tree.xmlBufUse(c_result_buffer)].decode('UTF-8')
+ else:
+ result = (tree.xmlBufContent(
+ c_result_buffer))[:tree.xmlBufUse(c_result_buffer)]
+ finally:
+ error_result = tree.xmlOutputBufferClose(c_buffer)
+ if error_result == -1:
+ _raiseSerialisationError(error_result)
+ return result
+
+cdef bytes _tostringC14N(element_or_tree, bint exclusive, bint with_comments, inclusive_ns_prefixes):
+ cdef xmlDoc* c_doc
+ cdef xmlChar* c_buffer = NULL
+ cdef int byte_count = -1
+ cdef bytes result
+ cdef _Document doc
+ cdef _Element element
+ cdef xmlChar **c_inclusive_ns_prefixes
+
+ if isinstance(element_or_tree, _Element):
+ _assertValidNode(<_Element>element_or_tree)
+ doc = (<_Element>element_or_tree)._doc
+ c_doc = _plainFakeRootDoc(doc._c_doc, (<_Element>element_or_tree)._c_node, 0)
+ else:
+ doc = _documentOrRaise(element_or_tree)
+ _assertValidDoc(doc)
+ c_doc = doc._c_doc
+
+ c_inclusive_ns_prefixes = _convert_ns_prefixes(c_doc.dict, inclusive_ns_prefixes) if inclusive_ns_prefixes else NULL
+ try:
+ with nogil:
+ byte_count = c14n.xmlC14NDocDumpMemory(
+ c_doc, NULL, exclusive, c_inclusive_ns_prefixes, with_comments, &c_buffer)
+
+ finally:
+ _destroyFakeDoc(doc._c_doc, c_doc)
+ if c_inclusive_ns_prefixes is not NULL:
+ python.lxml_free(c_inclusive_ns_prefixes)
+
+ if byte_count < 0 or c_buffer is NULL:
+ if c_buffer is not NULL:
+ tree.xmlFree(c_buffer)
+ raise C14NError, "C14N failed"
+ try:
+ result = c_buffer[:byte_count]
+ finally:
+ tree.xmlFree(c_buffer)
+ return result
+
+cdef _raiseSerialisationError(int error_result):
+ if error_result == xmlerror.XML_ERR_NO_MEMORY:
+ raise MemoryError()
+ message = ErrorTypes._getName(error_result)
+ if message is None:
+ message = f"unknown error {error_result}"
+ raise SerialisationError, message
+
+############################################################
+# low-level serialisation functions
+
+cdef void _writeDoctype(tree.xmlOutputBuffer* c_buffer,
+ const_xmlChar* c_doctype) noexcept nogil:
+ tree.xmlOutputBufferWrite(c_buffer, tree.xmlStrlen(c_doctype),
+ c_doctype)
+ tree.xmlOutputBufferWriteString(c_buffer, "\n")
+
+cdef void _writeNodeToBuffer(tree.xmlOutputBuffer* c_buffer,
+ xmlNode* c_node, const_char* encoding, const_xmlChar* c_doctype,
+ int c_method, bint write_xml_declaration,
+ bint write_complete_document,
+ bint pretty_print, bint with_tail,
+ int standalone) noexcept nogil:
+ cdef xmlNode* c_nsdecl_node
+ cdef xmlDoc* c_doc = c_node.doc
+ if write_xml_declaration and c_method == OUTPUT_METHOD_XML:
+ _writeDeclarationToBuffer(c_buffer, c_doc.version, encoding, standalone)
+
+ # comments/processing instructions before doctype declaration
+ if write_complete_document and not c_buffer.error and c_doc.intSubset:
+ _writePrevSiblings(c_buffer, c_doc.intSubset, encoding, pretty_print)
+
+ if c_doctype:
+ _writeDoctype(c_buffer, c_doctype)
+ # write internal DTD subset, preceding PIs/comments, etc.
+ if write_complete_document and not c_buffer.error:
+ if c_doctype is NULL:
+ _writeDtdToBuffer(c_buffer, c_doc, c_node.name, c_method, encoding)
+ _writePrevSiblings(c_buffer, c_node, encoding, pretty_print)
+
+ c_nsdecl_node = c_node
+ if not c_node.parent or c_node.parent.type != tree.XML_DOCUMENT_NODE:
+ # copy the node and add namespaces from parents
+ # this is required to make libxml write them
+ c_nsdecl_node = tree.xmlCopyNode(c_node, 2)
+ if not c_nsdecl_node:
+ c_buffer.error = xmlerror.XML_ERR_NO_MEMORY
+ return
+ _copyParentNamespaces(c_node, c_nsdecl_node)
+
+ c_nsdecl_node.parent = c_node.parent
+ c_nsdecl_node.children = c_node.children
+ c_nsdecl_node.last = c_node.last
+
+ # write node
+ if c_method == OUTPUT_METHOD_HTML:
+ tree.htmlNodeDumpFormatOutput(
+ c_buffer, c_doc, c_nsdecl_node, encoding, pretty_print)
+ else:
+ tree.xmlNodeDumpOutput(
+ c_buffer, c_doc, c_nsdecl_node, 0, pretty_print, encoding)
+
+ if c_nsdecl_node is not c_node:
+ # clean up
+ c_nsdecl_node.children = c_nsdecl_node.last = NULL
+ tree.xmlFreeNode(c_nsdecl_node)
+
+ if c_buffer.error:
+ return
+
+ # write tail, trailing comments, etc.
+ if with_tail:
+ _writeTail(c_buffer, c_node, encoding, c_method, pretty_print)
+ if write_complete_document:
+ _writeNextSiblings(c_buffer, c_node, encoding, pretty_print)
+ if pretty_print:
+ tree.xmlOutputBufferWrite(c_buffer, 1, "\n")
+
+cdef void _writeDeclarationToBuffer(tree.xmlOutputBuffer* c_buffer,
+ const_xmlChar* version, const_char* encoding,
+ int standalone) noexcept nogil:
+ if version is NULL:
+ version = "1.0"
+ tree.xmlOutputBufferWrite(c_buffer, 15, "version)
+ tree.xmlOutputBufferWrite(c_buffer, 12, "' encoding='")
+ tree.xmlOutputBufferWriteString(c_buffer, encoding)
+ if standalone == 0:
+ tree.xmlOutputBufferWrite(c_buffer, 20, "' standalone='no'?>\n")
+ elif standalone == 1:
+ tree.xmlOutputBufferWrite(c_buffer, 21, "' standalone='yes'?>\n")
+ else:
+ tree.xmlOutputBufferWrite(c_buffer, 4, "'?>\n")
+
+cdef void _writeDtdToBuffer(tree.xmlOutputBuffer* c_buffer,
+ xmlDoc* c_doc, const_xmlChar* c_root_name,
+ int c_method, const_char* encoding) noexcept nogil:
+ cdef tree.xmlDtd* c_dtd
+ cdef xmlNode* c_node
+ cdef char* quotechar
+ c_dtd = c_doc.intSubset
+ if not c_dtd or not c_dtd.name:
+ return
+
+ # Name in document type declaration must match the root element tag.
+ # For XML, case sensitive match, for HTML insensitive.
+ if c_method == OUTPUT_METHOD_HTML:
+ if tree.xmlStrcasecmp(c_root_name, c_dtd.name) != 0:
+ return
+ else:
+ if tree.xmlStrcmp(c_root_name, c_dtd.name) != 0:
+ return
+
+ tree.xmlOutputBufferWrite(c_buffer, 10, "c_dtd.name)
+
+ cdef const_xmlChar* public_id = c_dtd.ExternalID
+ cdef const_xmlChar* sys_url = c_dtd.SystemID
+ if public_id and public_id[0] == b'\0':
+ public_id = NULL
+ if sys_url and sys_url[0] == b'\0':
+ sys_url = NULL
+
+ if public_id:
+ tree.xmlOutputBufferWrite(c_buffer, 9, ' PUBLIC "')
+ tree.xmlOutputBufferWriteString(c_buffer, public_id)
+ if sys_url:
+ tree.xmlOutputBufferWrite(c_buffer, 2, '" ')
+ else:
+ tree.xmlOutputBufferWrite(c_buffer, 1, '"')
+ elif sys_url:
+ tree.xmlOutputBufferWrite(c_buffer, 8, ' SYSTEM ')
+
+ if sys_url:
+ if tree.xmlStrchr(sys_url, b'"'):
+ quotechar = '\''
+ else:
+ quotechar = '"'
+ tree.xmlOutputBufferWrite(c_buffer, 1, quotechar)
+ tree.xmlOutputBufferWriteString(c_buffer, sys_url)
+ tree.xmlOutputBufferWrite(c_buffer, 1, quotechar)
+
+ if (not c_dtd.entities and not c_dtd.elements and
+ not c_dtd.attributes and not c_dtd.notations and
+ not c_dtd.pentities):
+ tree.xmlOutputBufferWrite(c_buffer, 2, '>\n')
+ return
+
+ tree.xmlOutputBufferWrite(c_buffer, 3, ' [\n')
+ if c_dtd.notations and not c_buffer.error:
+ c_buf = tree.xmlBufferCreate()
+ if not c_buf:
+ c_buffer.error = xmlerror.XML_ERR_NO_MEMORY
+ return
+ tree.xmlDumpNotationTable(c_buf, c_dtd.notations)
+ tree.xmlOutputBufferWrite(
+ c_buffer, tree.xmlBufferLength(c_buf),
+ tree.xmlBufferContent(c_buf))
+ tree.xmlBufferFree(c_buf)
+ c_node = c_dtd.children
+ while c_node and not c_buffer.error:
+ tree.xmlNodeDumpOutput(c_buffer, c_node.doc, c_node, 0, 0, encoding)
+ c_node = c_node.next
+ tree.xmlOutputBufferWrite(c_buffer, 3, "]>\n")
+
+cdef void _writeTail(tree.xmlOutputBuffer* c_buffer, xmlNode* c_node,
+ const_char* encoding, int c_method, bint pretty_print) noexcept nogil:
+ "Write the element tail."
+ c_node = c_node.next
+ while c_node and not c_buffer.error and c_node.type in (
+ tree.XML_TEXT_NODE, tree.XML_CDATA_SECTION_NODE):
+ if c_method == OUTPUT_METHOD_HTML:
+ tree.htmlNodeDumpFormatOutput(
+ c_buffer, c_node.doc, c_node, encoding, pretty_print)
+ else:
+ tree.xmlNodeDumpOutput(
+ c_buffer, c_node.doc, c_node, 0, pretty_print, encoding)
+ c_node = c_node.next
+
+cdef void _writePrevSiblings(tree.xmlOutputBuffer* c_buffer, xmlNode* c_node,
+ const_char* encoding, bint pretty_print) noexcept nogil:
+ cdef xmlNode* c_sibling
+ if c_node.parent and _isElement(c_node.parent):
+ return
+ # we are at a root node, so add PI and comment siblings
+ c_sibling = c_node
+ while c_sibling.prev and \
+ (c_sibling.prev.type == tree.XML_PI_NODE or
+ c_sibling.prev.type == tree.XML_COMMENT_NODE):
+ c_sibling = c_sibling.prev
+ while c_sibling is not c_node and not c_buffer.error:
+ tree.xmlNodeDumpOutput(c_buffer, c_node.doc, c_sibling, 0,
+ pretty_print, encoding)
+ if pretty_print:
+ tree.xmlOutputBufferWriteString(c_buffer, "\n")
+ c_sibling = c_sibling.next
+
+cdef void _writeNextSiblings(tree.xmlOutputBuffer* c_buffer, xmlNode* c_node,
+ const_char* encoding, bint pretty_print) noexcept nogil:
+ cdef xmlNode* c_sibling
+ if c_node.parent and _isElement(c_node.parent):
+ return
+ # we are at a root node, so add PI and comment siblings
+ c_sibling = c_node.next
+ while not c_buffer.error and c_sibling and \
+ (c_sibling.type == tree.XML_PI_NODE or
+ c_sibling.type == tree.XML_COMMENT_NODE):
+ if pretty_print:
+ tree.xmlOutputBufferWriteString(c_buffer, "\n")
+ tree.xmlNodeDumpOutput(c_buffer, c_node.doc, c_sibling, 0,
+ pretty_print, encoding)
+ c_sibling = c_sibling.next
+
+
+# copied and adapted from libxml2 (xmlBufAttrSerializeTxtContent())
+cdef _write_attr_string(tree.xmlOutputBuffer* buf, const char *string):
+ cdef const char *base
+ cdef const char *cur
+
+ if string == NULL:
+ return
+
+ base = cur = string
+ while cur[0] != 0:
+ if cur[0] == b'\n':
+ if base != cur:
+ tree.xmlOutputBufferWrite(buf, cur - base, base)
+
+ tree.xmlOutputBufferWrite(buf, 5, "
")
+ cur += 1
+ base = cur
+
+ elif cur[0] == b'\r':
+ if base != cur:
+ tree.xmlOutputBufferWrite(buf, cur - base, base)
+
+ tree.xmlOutputBufferWrite(buf, 5, "
")
+ cur += 1
+ base = cur
+
+ elif cur[0] == b'\t':
+ if base != cur:
+ tree.xmlOutputBufferWrite(buf, cur - base, base)
+
+ tree.xmlOutputBufferWrite(buf, 4, " ")
+ cur += 1
+ base = cur
+
+ elif cur[0] == b'"':
+ if base != cur:
+ tree.xmlOutputBufferWrite(buf, cur - base, base)
+
+ tree.xmlOutputBufferWrite(buf, 6, """)
+ cur += 1
+ base = cur
+
+ elif cur[0] == b'<':
+ if base != cur:
+ tree.xmlOutputBufferWrite(buf, cur - base, base)
+
+ tree.xmlOutputBufferWrite(buf, 4, "<")
+ cur += 1
+ base = cur
+
+ elif cur[0] == b'>':
+ if base != cur:
+ tree.xmlOutputBufferWrite(buf, cur - base, base)
+
+ tree.xmlOutputBufferWrite(buf, 4, ">")
+ cur += 1
+ base = cur
+ elif cur[0] == b'&':
+ if base != cur:
+ tree.xmlOutputBufferWrite(buf, cur - base, base)
+
+ tree.xmlOutputBufferWrite(buf, 5, "&")
+ cur += 1
+ base = cur
+
+ else:
+ # Leave further encoding and escaping to the buffer encoder.
+ cur += 1
+
+ if base != cur:
+ tree.xmlOutputBufferWrite(buf, cur - base, base)
+
+
+cdef void _write_cdata_section(tree.xmlOutputBuffer* buf, const char* c_data, const char* c_end):
+ tree.xmlOutputBufferWrite(buf, 9, " limits.INT_MAX:
+ tree.xmlOutputBufferWrite(buf, limits.INT_MAX, c_data)
+ c_data += limits.INT_MAX
+ tree.xmlOutputBufferWrite(buf, c_end - c_data, c_data)
+ tree.xmlOutputBufferWrite(buf, 3, "]]>")
+
+
+cdef _write_cdata_string(tree.xmlOutputBuffer* buf, bytes bstring):
+ cdef const char* c_data = bstring
+ cdef const char* c_end = c_data + len(bstring)
+ cdef const char* c_pos = c_data
+ cdef bint nothing_written = True
+
+ while True:
+ c_pos = cstring_h.memchr(c_pos, b']', c_end - c_pos)
+ if not c_pos:
+ break
+ c_pos += 1
+ next_char = c_pos[0]
+ c_pos += 1
+ if next_char != b']':
+ continue
+ # Found ']]', c_pos points to next character.
+ while c_pos[0] == b']':
+ c_pos += 1
+ if c_pos[0] != b'>':
+ if c_pos == c_end:
+ break
+ # c_pos[0] is neither ']' nor '>', continue with next character.
+ c_pos += 1
+ continue
+
+ # Write section up to ']]' and start next block at trailing '>'.
+ _write_cdata_section(buf, c_data, c_pos)
+ nothing_written = False
+ c_data = c_pos
+ c_pos += 1
+
+ if nothing_written or c_data < c_end:
+ _write_cdata_section(buf, c_data, c_end)
+
+
+############################################################
+# output to file-like objects
+
+cdef object io_open
+from io import open as io_open
+
+cdef object gzip
+import gzip
+
+cdef object getwriter
+from codecs import getwriter
+cdef object utf8_writer = getwriter('utf8')
+
+cdef object contextmanager
+from contextlib import contextmanager
+
+cdef object _open_utf8_file
+
+@contextmanager
+def _open_utf8_file(file, compression=0):
+ file = _getFSPathOrObject(file)
+ if _isString(file):
+ if compression:
+ with gzip.GzipFile(file, mode='wb', compresslevel=compression) as zf:
+ yield utf8_writer(zf)
+ else:
+ with io_open(file, 'w', encoding='utf8') as f:
+ yield f
+ else:
+ if compression:
+ with gzip.GzipFile(fileobj=file, mode='wb', compresslevel=compression) as zf:
+ yield utf8_writer(zf)
+ else:
+ yield utf8_writer(file)
+
+
+@cython.final
+@cython.internal
+cdef class _FilelikeWriter:
+ cdef object _filelike
+ cdef object _close_filelike
+ cdef _ExceptionContext _exc_context
+ cdef _ErrorLog error_log
+
+ def __cinit__(self, filelike, exc_context=None, compression=None, close=False):
+ if compression is not None and compression > 0:
+ filelike = GzipFile(
+ fileobj=filelike, mode='wb', compresslevel=compression)
+ self._close_filelike = filelike.close
+ elif close:
+ self._close_filelike = filelike.close
+ self._filelike = filelike
+ if exc_context is None:
+ self._exc_context = _ExceptionContext()
+ else:
+ self._exc_context = exc_context
+ self.error_log = _ErrorLog()
+
+ cdef tree.xmlOutputBuffer* _createOutputBuffer(
+ self, tree.xmlCharEncodingHandler* enchandler) except NULL:
+ cdef tree.xmlOutputBuffer* c_buffer
+ c_buffer = tree.xmlOutputBufferCreateIO(
+ _writeFilelikeWriter, _closeFilelikeWriter,
+ self, enchandler)
+ if c_buffer is NULL:
+ raise IOError, "Could not create I/O writer context."
+ return c_buffer
+
+ cdef int write(self, char* c_buffer, int size) noexcept:
+ try:
+ if self._filelike is None:
+ raise IOError, "File is already closed"
+ py_buffer = c_buffer[:size]
+ self._filelike.write(py_buffer)
+ except:
+ size = -1
+ self._exc_context._store_raised()
+ finally:
+ return size # and swallow any further exceptions
+
+ cdef int close(self) noexcept:
+ retval = 0
+ try:
+ if self._close_filelike is not None:
+ self._close_filelike()
+ # we should not close the file here as we didn't open it
+ self._filelike = None
+ except:
+ retval = -1
+ self._exc_context._store_raised()
+ finally:
+ return retval # and swallow any further exceptions
+
+cdef int _writeFilelikeWriter(void* ctxt, char* c_buffer, int length) noexcept:
+ return (<_FilelikeWriter>ctxt).write(c_buffer, length)
+
+cdef int _closeFilelikeWriter(void* ctxt) noexcept:
+ return (<_FilelikeWriter>ctxt).close()
+
+cdef _tofilelike(f, _Element element, encoding, doctype, method,
+ bint write_xml_declaration, bint write_doctype,
+ bint pretty_print, bint with_tail, int standalone,
+ int compression):
+ cdef _FilelikeWriter writer = None
+ cdef tree.xmlOutputBuffer* c_buffer
+ cdef tree.xmlCharEncodingHandler* enchandler
+ cdef const_char* c_enc
+ cdef const_xmlChar* c_doctype
+ cdef int error_result
+
+ c_method = _findOutputMethod(method)
+ if c_method == OUTPUT_METHOD_TEXT:
+ data = _textToString(element._c_node, encoding, with_tail)
+ if compression:
+ bytes_out = BytesIO()
+ with GzipFile(fileobj=bytes_out, mode='wb', compresslevel=compression) as gzip_file:
+ gzip_file.write(data)
+ data = bytes_out.getvalue()
+ f = _getFSPathOrObject(f)
+ if _isString(f):
+ filename8 = _encodeFilename(f)
+ with open(filename8, 'wb') as f:
+ f.write(data)
+ else:
+ f.write(data)
+ return
+
+ if encoding is None:
+ c_enc = NULL
+ else:
+ encoding = _utf8(encoding)
+ c_enc = _cstr(encoding)
+ if doctype is None:
+ c_doctype = NULL
+ else:
+ doctype = _utf8(doctype)
+ c_doctype = _xcstr(doctype)
+
+ writer = _create_output_buffer(f, c_enc, compression, &c_buffer, close=False)
+ if writer is None:
+ with nogil:
+ error_result = _serialise_node(
+ c_buffer, c_doctype, c_enc, element._c_node, c_method,
+ write_xml_declaration, write_doctype, pretty_print, with_tail, standalone)
+ else:
+ error_result = _serialise_node(
+ c_buffer, c_doctype, c_enc, element._c_node, c_method,
+ write_xml_declaration, write_doctype, pretty_print, with_tail, standalone)
+
+ if writer is not None:
+ writer._exc_context._raise_if_stored()
+ if error_result != xmlerror.XML_ERR_OK:
+ _raiseSerialisationError(error_result)
+
+
+cdef int _serialise_node(tree.xmlOutputBuffer* c_buffer, const_xmlChar* c_doctype,
+ const_char* c_enc, xmlNode* c_node, int c_method,
+ bint write_xml_declaration, bint write_doctype, bint pretty_print,
+ bint with_tail, int standalone) noexcept nogil:
+ _writeNodeToBuffer(
+ c_buffer, c_node, c_enc, c_doctype, c_method,
+ write_xml_declaration, write_doctype, pretty_print, with_tail, standalone)
+ error_result = c_buffer.error
+ if error_result == xmlerror.XML_ERR_OK:
+ error_result = tree.xmlOutputBufferClose(c_buffer)
+ if error_result != -1:
+ error_result = xmlerror.XML_ERR_OK
+ else:
+ tree.xmlOutputBufferClose(c_buffer)
+ return error_result
+
+
+cdef _FilelikeWriter _create_output_buffer(
+ f, const_char* c_enc, int c_compression,
+ tree.xmlOutputBuffer** c_buffer_ret, bint close):
+ cdef tree.xmlOutputBuffer* c_buffer
+ cdef _FilelikeWriter writer
+ cdef bytes filename8
+ enchandler = tree.xmlFindCharEncodingHandler(c_enc)
+ if enchandler is NULL:
+ raise LookupError(
+ f"unknown encoding: '{c_enc.decode('UTF-8') if c_enc is not NULL else u''}'")
+ try:
+ f = _getFSPathOrObject(f)
+
+ if c_compression and not HAS_ZLIB_COMPRESSION and _isString(f):
+ # Let "_FilelikeWriter" fall back to Python's GzipFile.
+ f = open(f, mode="wb")
+ close = True
+
+ if _isString(f):
+ filename8 = _encodeFilename(f)
+ if b'%' in filename8 and (
+ # Exclude absolute Windows paths and file:// URLs.
+ _isFilePath(filename8) not in (NO_FILE_PATH, ABS_WIN_FILE_PATH)
+ or filename8[:7].lower() == b'file://'):
+ # A file path (not a URL) containing the '%' URL escape character.
+ # libxml2 uses URL-unescaping on these, so escape the path before passing it in.
+ filename8 = filename8.replace(b'%', b'%25')
+ c_buffer = tree.xmlOutputBufferCreateFilename(
+ _cstr(filename8), enchandler, c_compression)
+ if c_buffer is NULL:
+ python.PyErr_SetFromErrno(IOError) # raises IOError
+ writer = None
+ elif hasattr(f, 'write'):
+ writer = _FilelikeWriter(f, compression=c_compression, close=close)
+ c_buffer = writer._createOutputBuffer(enchandler)
+ else:
+ raise TypeError(
+ f"File or filename expected, got '{python._fqtypename(f).decode('UTF-8')}'")
+ except:
+ tree.xmlCharEncCloseFunc(enchandler)
+ raise
+ c_buffer_ret[0] = c_buffer
+ return writer
+
+cdef xmlChar **_convert_ns_prefixes(tree.xmlDict* c_dict, ns_prefixes) except NULL:
+ cdef size_t i, num_ns_prefixes = len(ns_prefixes)
+ # Need to allocate one extra memory block to handle last NULL entry
+ c_ns_prefixes = python.lxml_malloc(num_ns_prefixes + 1, sizeof(xmlChar*))
+ if not c_ns_prefixes:
+ raise MemoryError()
+ i = 0
+ try:
+ for prefix in ns_prefixes:
+ prefix_utf = _utf8(prefix)
+ c_prefix_len = len(prefix_utf)
+ if c_prefix_len > limits.INT_MAX:
+ raise ValueError("Prefix too long")
+ c_prefix = tree.xmlDictExists(c_dict, _xcstr(prefix_utf), c_prefix_len)
+ if c_prefix:
+ # unknown prefixes do not need to get serialised
+ c_ns_prefixes[i] = c_prefix
+ i += 1
+ except:
+ python.lxml_free(c_ns_prefixes)
+ raise
+
+ c_ns_prefixes[i] = NULL # append end marker
+ return c_ns_prefixes
+
+cdef _tofilelikeC14N(f, _Element element, bint exclusive, bint with_comments,
+ int compression, inclusive_ns_prefixes):
+ cdef _FilelikeWriter writer = None
+ cdef tree.xmlOutputBuffer* c_buffer
+ cdef xmlChar **c_inclusive_ns_prefixes = NULL
+ cdef char* c_filename
+ cdef xmlDoc* c_base_doc
+ cdef xmlDoc* c_doc
+ cdef int bytes_count, error = 0
+
+ c_base_doc = element._c_node.doc
+ c_doc = _fakeRootDoc(c_base_doc, element._c_node)
+ try:
+ c_inclusive_ns_prefixes = (
+ _convert_ns_prefixes(c_doc.dict, inclusive_ns_prefixes)
+ if inclusive_ns_prefixes else NULL)
+
+ f = _getFSPathOrObject(f)
+
+ close = False
+ if compression and not HAS_ZLIB_COMPRESSION and _isString(f):
+ # Let "_FilelikeWriter" fall back to Python's GzipFile.
+ f = open(f, mode="wb")
+ close = True
+
+ if _isString(f):
+ filename8 = _encodeFilename(f)
+ c_filename = _cstr(filename8)
+ with nogil:
+ error = c14n.xmlC14NDocSave(
+ c_doc, NULL, exclusive, c_inclusive_ns_prefixes,
+ with_comments, c_filename, compression)
+ elif hasattr(f, 'write'):
+ writer = _FilelikeWriter(f, compression=compression, close=close)
+ c_buffer = writer._createOutputBuffer(NULL)
+ try:
+ with writer.error_log:
+ bytes_count = c14n.xmlC14NDocSaveTo(
+ c_doc, NULL, exclusive, c_inclusive_ns_prefixes,
+ with_comments, c_buffer)
+ finally:
+ error = tree.xmlOutputBufferClose(c_buffer)
+ if bytes_count < 0:
+ error = bytes_count
+ elif error != -1:
+ error = xmlerror.XML_ERR_OK
+ else:
+ raise TypeError(f"File or filename expected, got '{python._fqtypename(f).decode('UTF-8')}'")
+ finally:
+ _destroyFakeDoc(c_base_doc, c_doc)
+ if c_inclusive_ns_prefixes is not NULL:
+ python.lxml_free(c_inclusive_ns_prefixes)
+
+ if writer is not None:
+ writer._exc_context._raise_if_stored()
+
+ if error < 0:
+ message = "C14N failed"
+ if writer is not None:
+ errors = writer.error_log
+ if len(errors):
+ message = errors[0].message
+ raise C14NError(message)
+
+
+# C14N 2.0
+
+def canonicalize(xml_data=None, *, out=None, from_file=None, **options):
+ """Convert XML to its C14N 2.0 serialised form.
+
+ If *out* is provided, it must be a file or file-like object that receives
+ the serialised canonical XML output (text, not bytes) through its ``.write()``
+ method. To write to a file, open it in text mode with encoding "utf-8".
+ If *out* is not provided, this function returns the output as text string.
+
+ Either *xml_data* (an XML string, tree or Element) or *file*
+ (a file path or file-like object) must be provided as input.
+
+ The configuration options are the same as for the ``C14NWriterTarget``.
+ """
+ if xml_data is None and from_file is None:
+ raise ValueError("Either 'xml_data' or 'from_file' must be provided as input")
+
+ sio = None
+ if out is None:
+ sio = out = StringIO()
+
+ target = C14NWriterTarget(out.write, **options)
+
+ if xml_data is not None and not isinstance(xml_data, basestring):
+ _tree_to_target(xml_data, target)
+ return sio.getvalue() if sio is not None else None
+
+ cdef _FeedParser parser = XMLParser(
+ target=target,
+ attribute_defaults=True,
+ collect_ids=False,
+ )
+
+ if xml_data is not None:
+ parser.feed(xml_data)
+ parser.close()
+ elif from_file is not None:
+ try:
+ _parseDocument(from_file, parser, base_url=None)
+ except _TargetParserResult:
+ pass
+
+ return sio.getvalue() if sio is not None else None
+
+
+cdef _tree_to_target(element, target):
+ for event, elem in iterwalk(element, events=('start', 'end', 'start-ns', 'comment', 'pi')):
+ text = None
+ if event == 'start':
+ target.start(elem.tag, elem.attrib)
+ text = elem.text
+ elif event == 'end':
+ target.end(elem.tag)
+ text = elem.tail
+ elif event == 'start-ns':
+ target.start_ns(*elem)
+ continue
+ elif event == 'comment':
+ target.comment(elem.text)
+ text = elem.tail
+ elif event == 'pi':
+ target.pi(elem.target, elem.text)
+ text = elem.tail
+ if text:
+ target.data(text)
+ return target.close()
+
+
+cdef object _looks_like_prefix_name = re.compile(r'^\w+:\w+$', re.UNICODE).match
+
+
+cdef class C14NWriterTarget:
+ """
+ Canonicalization writer target for the XMLParser.
+
+ Serialises parse events to XML C14N 2.0.
+
+ Configuration options:
+
+ - *with_comments*: set to true to include comments
+ - *strip_text*: set to true to strip whitespace before and after text content
+ - *rewrite_prefixes*: set to true to replace namespace prefixes by "n{number}"
+ - *qname_aware_tags*: a set of qname aware tag names in which prefixes
+ should be replaced in text content
+ - *qname_aware_attrs*: a set of qname aware attribute names in which prefixes
+ should be replaced in text content
+ - *exclude_attrs*: a set of attribute names that should not be serialised
+ - *exclude_tags*: a set of tag names that should not be serialised
+ """
+ cdef object _write
+ cdef list _data
+ cdef set _qname_aware_tags
+ cdef object _find_qname_aware_attrs
+ cdef list _declared_ns_stack
+ cdef list _ns_stack
+ cdef dict _prefix_map
+ cdef list _preserve_space
+ cdef tuple _pending_start
+ cdef set _exclude_tags
+ cdef set _exclude_attrs
+ cdef Py_ssize_t _ignored_depth
+ cdef bint _with_comments
+ cdef bint _strip_text
+ cdef bint _rewrite_prefixes
+ cdef bint _root_seen
+ cdef bint _root_done
+
+ def __init__(self, write, *,
+ with_comments=False, strip_text=False, rewrite_prefixes=False,
+ qname_aware_tags=None, qname_aware_attrs=None,
+ exclude_attrs=None, exclude_tags=None):
+ self._write = write
+ self._data = []
+ self._with_comments = with_comments
+ self._strip_text = strip_text
+ self._exclude_attrs = set(exclude_attrs) if exclude_attrs else None
+ self._exclude_tags = set(exclude_tags) if exclude_tags else None
+
+ self._rewrite_prefixes = rewrite_prefixes
+ if qname_aware_tags:
+ self._qname_aware_tags = set(qname_aware_tags)
+ else:
+ self._qname_aware_tags = None
+ if qname_aware_attrs:
+ self._find_qname_aware_attrs = set(qname_aware_attrs).intersection
+ else:
+ self._find_qname_aware_attrs = None
+
+ # Stack with globally and newly declared namespaces as (uri, prefix) pairs.
+ self._declared_ns_stack = [[
+ ("http://www.w3.org/XML/1998/namespace", "xml"),
+ ]]
+ # Stack with user declared namespace prefixes as (uri, prefix) pairs.
+ self._ns_stack = []
+ if not rewrite_prefixes:
+ self._ns_stack.append(_DEFAULT_NAMESPACE_PREFIXES_ITEMS)
+ self._ns_stack.append([])
+ self._prefix_map = {}
+ self._preserve_space = [False]
+ self._pending_start = None
+ self._ignored_depth = 0
+ self._root_seen = False
+ self._root_done = False
+
+ def _iter_namespaces(self, ns_stack):
+ for namespaces in reversed(ns_stack):
+ if namespaces: # almost no element declares new namespaces
+ yield from namespaces
+
+ cdef _resolve_prefix_name(self, prefixed_name):
+ prefix, name = prefixed_name.split(':', 1)
+ for uri, p in self._iter_namespaces(self._ns_stack):
+ if p == prefix:
+ return f'{{{uri}}}{name}'
+ raise ValueError(f'Prefix {prefix} of QName "{prefixed_name}" is not declared in scope')
+
+ cdef _qname(self, qname, uri=None):
+ if uri is None:
+ uri, tag = qname[1:].rsplit('}', 1) if qname[:1] == '{' else ('', qname)
+ else:
+ tag = qname
+
+ prefixes_seen = set()
+ for u, prefix in self._iter_namespaces(self._declared_ns_stack):
+ if u == uri and prefix not in prefixes_seen:
+ return f'{prefix}:{tag}' if prefix else tag, tag, uri
+ prefixes_seen.add(prefix)
+
+ # Not declared yet => add new declaration.
+ if self._rewrite_prefixes:
+ if uri in self._prefix_map:
+ prefix = self._prefix_map[uri]
+ else:
+ prefix = self._prefix_map[uri] = f'n{len(self._prefix_map)}'
+ self._declared_ns_stack[-1].append((uri, prefix))
+ return f'{prefix}:{tag}', tag, uri
+
+ if not uri and '' not in prefixes_seen:
+ # No default namespace declared => no prefix needed.
+ return tag, tag, uri
+
+ for u, prefix in self._iter_namespaces(self._ns_stack):
+ if u == uri:
+ self._declared_ns_stack[-1].append((uri, prefix))
+ return f'{prefix}:{tag}' if prefix else tag, tag, uri
+
+ if not uri:
+ # As soon as a default namespace is defined,
+ # anything that has no namespace (and thus, no prefix) goes there.
+ return tag, tag, uri
+
+ raise ValueError(f'Namespace "{uri}" of name "{tag}" is not declared in scope')
+
+ def data(self, data):
+ if not self._ignored_depth:
+ self._data.append(data)
+
+ cdef _flush(self):
+ cdef unicode data = ''.join(self._data)
+ del self._data[:]
+ if self._strip_text and not self._preserve_space[-1]:
+ data = data.strip()
+ if self._pending_start is not None:
+ (tag, attrs, new_namespaces), self._pending_start = self._pending_start, None
+ qname_text = data if ':' in data and _looks_like_prefix_name(data) else None
+ self._start(tag, attrs, new_namespaces, qname_text)
+ if qname_text is not None:
+ return
+ if data and self._root_seen:
+ self._write(_escape_cdata_c14n(data))
+
+ def start_ns(self, prefix, uri):
+ if self._ignored_depth:
+ return
+ # we may have to resolve qnames in text content
+ if self._data:
+ self._flush()
+ self._ns_stack[-1].append((uri, prefix))
+
+ def start(self, tag, attrs):
+ if self._exclude_tags is not None and (
+ self._ignored_depth or tag in self._exclude_tags):
+ self._ignored_depth += 1
+ return
+ if self._data:
+ self._flush()
+
+ new_namespaces = []
+ self._declared_ns_stack.append(new_namespaces)
+
+ if self._qname_aware_tags is not None and tag in self._qname_aware_tags:
+ # Need to parse text first to see if it requires a prefix declaration.
+ self._pending_start = (tag, attrs, new_namespaces)
+ return
+ self._start(tag, attrs, new_namespaces)
+
+ cdef _start(self, tag, attrs, new_namespaces, qname_text=None):
+ if self._exclude_attrs is not None and attrs:
+ attrs = {k: v for k, v in attrs.items() if k not in self._exclude_attrs}
+
+ qnames = {tag, *attrs}
+ resolved_names = {}
+
+ # Resolve prefixes in attribute and tag text.
+ if qname_text is not None:
+ qname = resolved_names[qname_text] = self._resolve_prefix_name(qname_text)
+ qnames.add(qname)
+ if self._find_qname_aware_attrs is not None and attrs:
+ qattrs = self._find_qname_aware_attrs(attrs)
+ if qattrs:
+ for attr_name in qattrs:
+ value = attrs[attr_name]
+ if _looks_like_prefix_name(value):
+ qname = resolved_names[value] = self._resolve_prefix_name(value)
+ qnames.add(qname)
+ else:
+ qattrs = None
+ else:
+ qattrs = None
+
+ # Assign prefixes in lexicographical order of used URIs.
+ parsed_qnames = {n: self._qname(n) for n in sorted(
+ qnames, key=lambda n: n.split('}', 1))}
+
+ # Write namespace declarations in prefix order ...
+ if new_namespaces:
+ attr_list = [
+ ('xmlns:' + prefix if prefix else 'xmlns', uri)
+ for uri, prefix in new_namespaces
+ ]
+ attr_list.sort()
+ else:
+ # almost always empty
+ attr_list = []
+
+ # ... followed by attributes in URI+name order
+ if attrs:
+ for k, v in sorted(attrs.items()):
+ if qattrs is not None and k in qattrs and v in resolved_names:
+ v = parsed_qnames[resolved_names[v]][0]
+ attr_qname, attr_name, uri = parsed_qnames[k]
+ # No prefix for attributes in default ('') namespace.
+ attr_list.append((attr_qname if uri else attr_name, v))
+
+ # Honour xml:space attributes.
+ space_behaviour = attrs.get('{http://www.w3.org/XML/1998/namespace}space')
+ self._preserve_space.append(
+ space_behaviour == 'preserve' if space_behaviour
+ else self._preserve_space[-1])
+
+ # Write the tag.
+ write = self._write
+ write('<' + parsed_qnames[tag][0])
+ if attr_list:
+ write(''.join([f' {k}="{_escape_attrib_c14n(v)}"' for k, v in attr_list]))
+ write('>')
+
+ # Write the resolved qname text content.
+ if qname_text is not None:
+ write(_escape_cdata_c14n(parsed_qnames[resolved_names[qname_text]][0]))
+
+ self._root_seen = True
+ self._ns_stack.append([])
+
+ def end(self, tag):
+ if self._ignored_depth:
+ self._ignored_depth -= 1
+ return
+ if self._data:
+ self._flush()
+ self._write(f'{self._qname(tag)[0]}>')
+ self._preserve_space.pop()
+ self._root_done = len(self._preserve_space) == 1
+ self._declared_ns_stack.pop()
+ self._ns_stack.pop()
+
+ def comment(self, text):
+ if not self._with_comments:
+ return
+ if self._ignored_depth:
+ return
+ if self._root_done:
+ self._write('\n')
+ elif self._root_seen and self._data:
+ self._flush()
+ self._write(f'')
+ if not self._root_seen:
+ self._write('\n')
+
+ def pi(self, target, data):
+ if self._ignored_depth:
+ return
+ if self._root_done:
+ self._write('\n')
+ elif self._root_seen and self._data:
+ self._flush()
+ self._write(
+ f'{target} {_escape_cdata_c14n(data)}?>' if data else f'{target}?>')
+ if not self._root_seen:
+ self._write('\n')
+
+ def close(self):
+ return None
+
+
+cdef _raise_serialization_error(text):
+ raise TypeError("cannot serialize %r (type %s)" % (text, type(text).__name__))
+
+
+cdef unicode _escape_cdata_c14n(stext):
+ # escape character data
+ cdef unicode text
+ cdef Py_UCS4 ch
+ cdef Py_ssize_t start = 0, pos = 0
+ cdef list substrings = None
+ try:
+ text = unicode(stext)
+ except (TypeError, AttributeError):
+ return _raise_serialization_error(stext)
+
+ for pos, ch in enumerate(text):
+ if ch == '&':
+ escape = '&'
+ elif ch == '<':
+ escape = '<'
+ elif ch == '>':
+ escape = '>'
+ elif ch == '\r':
+ escape = '
'
+ else:
+ continue
+
+ if substrings is None:
+ substrings = []
+ if pos > start:
+ substrings.append(text[start:pos])
+ substrings.append(escape)
+ start = pos + 1
+
+ if substrings is None:
+ return text
+ if pos >= start:
+ substrings.append(text[start:pos+1])
+ return ''.join(substrings)
+
+
+cdef unicode _escape_attrib_c14n(stext):
+ # escape attribute value
+ cdef unicode text
+ cdef Py_UCS4 ch
+ cdef Py_ssize_t start = 0, pos = 0
+ cdef list substrings = None
+ try:
+ text = unicode(stext)
+ except (TypeError, AttributeError):
+ return _raise_serialization_error(stext)
+
+ for pos, ch in enumerate(text):
+ if ch == '&':
+ escape = '&'
+ elif ch == '<':
+ escape = '<'
+ elif ch == '"':
+ escape = '"'
+ elif ch == '\t':
+ escape = ' '
+ elif ch == '\n':
+ escape = '
'
+ elif ch == '\r':
+ escape = '
'
+ else:
+ continue
+
+ if substrings is None:
+ substrings = []
+ if pos > start:
+ substrings.append(text[start:pos])
+ substrings.append(escape)
+ start = pos + 1
+
+ if substrings is None:
+ return text
+ if pos >= start:
+ substrings.append(text[start:pos+1])
+ return ''.join(substrings)
+
+
+# incremental serialisation
+
+cdef class xmlfile:
+ """xmlfile(self, output_file, encoding=None, compression=None, close=False, buffered=True)
+
+ A simple mechanism for incremental XML serialisation.
+
+ Usage example::
+
+ with xmlfile("somefile.xml", encoding='utf-8') as xf:
+ xf.write_declaration(standalone=True)
+ xf.write_doctype('')
+
+ # generate an element (the root element)
+ with xf.element('root'):
+ # write a complete Element into the open root element
+ xf.write(etree.Element('test'))
+
+ # generate and write more Elements, e.g. through iterparse
+ for element in generate_some_elements():
+ # serialise generated elements into the XML file
+ xf.write(element)
+
+ # or write multiple Elements or strings at once
+ xf.write(etree.Element('start'), "text", etree.Element('end'))
+
+ If 'output_file' is a file(-like) object, passing ``close=True`` will
+ close it when exiting the context manager. By default, it is left
+ to the owner to do that. When a file path is used, lxml will take care
+ of opening and closing the file itself. Also, when a compression level
+ is set, lxml will deliberately close the file to make sure all data gets
+ compressed and written.
+
+ Setting ``buffered=False`` will flush the output after each operation,
+ such as opening or closing an ``xf.element()`` block or calling
+ ``xf.write()``. Alternatively, calling ``xf.flush()`` can be used to
+ explicitly flush any pending output when buffering is enabled.
+ """
+ cdef object output_file
+ cdef bytes encoding
+ cdef _IncrementalFileWriter writer
+ cdef _AsyncIncrementalFileWriter async_writer
+ cdef int compresslevel
+ cdef bint close
+ cdef bint buffered
+ cdef int method
+
+ def __init__(self, output_file not None, encoding=None, compression=None,
+ close=False, buffered=True):
+ self.output_file = output_file
+ self.encoding = _utf8orNone(encoding)
+ self.compresslevel = compression or 0
+ self.close = close
+ self.buffered = buffered
+ self.method = OUTPUT_METHOD_XML
+
+ def __enter__(self):
+ assert self.output_file is not None
+ self.writer = _IncrementalFileWriter(
+ self.output_file, self.encoding, self.compresslevel,
+ self.close, self.buffered, self.method)
+ return self.writer
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ if self.writer is not None:
+ old_writer, self.writer = self.writer, None
+ raise_on_error = exc_type is None
+ old_writer._close(raise_on_error)
+ if self.close:
+ self.output_file = None
+
+ async def __aenter__(self):
+ assert self.output_file is not None
+ if isinstance(self.output_file, basestring):
+ raise TypeError("Cannot asynchronously write to a plain file")
+ if not hasattr(self.output_file, 'write'):
+ raise TypeError("Output file needs an async .write() method")
+ self.async_writer = _AsyncIncrementalFileWriter(
+ self.output_file, self.encoding, self.compresslevel,
+ self.close, self.buffered, self.method)
+ return self.async_writer
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ if self.async_writer is not None:
+ old_writer, self.async_writer = self.async_writer, None
+ raise_on_error = exc_type is None
+ await old_writer._close(raise_on_error)
+ if self.close:
+ self.output_file = None
+
+
+cdef class htmlfile(xmlfile):
+ """htmlfile(self, output_file, encoding=None, compression=None, close=False, buffered=True)
+
+ A simple mechanism for incremental HTML serialisation. Works the same as
+ xmlfile.
+ """
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.method = OUTPUT_METHOD_HTML
+
+
+cdef enum _IncrementalFileWriterStatus:
+ WRITER_STARTING = 0
+ WRITER_DECL_WRITTEN = 1
+ WRITER_DTD_WRITTEN = 2
+ WRITER_IN_ELEMENT = 3
+ WRITER_FINISHED = 4
+
+
+@cython.final
+@cython.internal
+cdef class _IncrementalFileWriter:
+ cdef tree.xmlOutputBuffer* _c_out
+ cdef bytes _encoding
+ cdef const_char* _c_encoding
+ cdef _FilelikeWriter _target
+ cdef list _element_stack
+ cdef int _status
+ cdef int _method
+ cdef bint _buffered
+
+ def __cinit__(self, outfile, bytes encoding, int compresslevel, bint close,
+ bint buffered, int method):
+ self._status = WRITER_STARTING
+ self._element_stack = []
+ if encoding is None:
+ # We always need a document encoding to make the attribute serialisation
+ # of libxml2 identical to ours.
+ encoding = b'ASCII'
+ self._encoding = encoding
+ self._c_encoding = _cstr(encoding)
+ self._buffered = buffered
+ self._target = _create_output_buffer(
+ outfile, self._c_encoding, compresslevel, &self._c_out, close)
+ self._method = method
+
+ def __dealloc__(self):
+ if self._c_out is not NULL:
+ tree.xmlOutputBufferClose(self._c_out)
+
+ def write_declaration(self, version=None, standalone=None, doctype=None):
+ """write_declaration(self, version=None, standalone=None, doctype=None)
+
+ Write an XML declaration and (optionally) a doctype into the file.
+ """
+ assert self._c_out is not NULL
+ cdef const_xmlChar* c_version
+ cdef int c_standalone
+ if self._method != OUTPUT_METHOD_XML:
+ raise LxmlSyntaxError("only XML documents have declarations")
+ if self._status >= WRITER_DECL_WRITTEN:
+ raise LxmlSyntaxError("XML declaration already written")
+ version = _utf8orNone(version)
+ c_version = _xcstr(version) if version is not None else NULL
+ doctype = _utf8orNone(doctype)
+ if standalone is None:
+ c_standalone = -1
+ else:
+ c_standalone = 1 if standalone else 0
+ _writeDeclarationToBuffer(self._c_out, c_version, self._c_encoding, c_standalone)
+ if doctype is not None:
+ _writeDoctype(self._c_out, _xcstr(doctype))
+ self._status = WRITER_DTD_WRITTEN
+ else:
+ self._status = WRITER_DECL_WRITTEN
+ if not self._buffered:
+ tree.xmlOutputBufferFlush(self._c_out)
+ self._handle_error(self._c_out.error)
+
+ def write_doctype(self, doctype):
+ """write_doctype(self, doctype)
+
+ Writes the given doctype declaration verbatimly into the file.
+ """
+ assert self._c_out is not NULL
+ if doctype is None:
+ return
+ if self._status >= WRITER_DTD_WRITTEN:
+ raise LxmlSyntaxError("DOCTYPE already written or cannot write it here")
+ doctype = _utf8(doctype)
+ _writeDoctype(self._c_out, _xcstr(doctype))
+ self._status = WRITER_DTD_WRITTEN
+ if not self._buffered:
+ tree.xmlOutputBufferFlush(self._c_out)
+ self._handle_error(self._c_out.error)
+
+ def method(self, method):
+ """method(self, method)
+
+ Returns a context manager that overrides and restores the output method.
+ method is one of (None, 'xml', 'html') where None means 'xml'.
+ """
+ assert self._c_out is not NULL
+ c_method = self._method if method is None else _findOutputMethod(method)
+ return _MethodChanger(self, c_method)
+
+ def element(self, tag, attrib=None, nsmap=None, method=None, **_extra):
+ """element(self, tag, attrib=None, nsmap=None, method, **_extra)
+
+ Returns a context manager that writes an opening and closing tag.
+ method is one of (None, 'xml', 'html') where None means 'xml'.
+ """
+ assert self._c_out is not NULL
+ attributes = []
+ if attrib is not None:
+ for name, value in _iter_attrib(attrib):
+ if name not in _extra:
+ ns, name = _getNsTag(name)
+ attributes.append((ns, name, _utf8(value)))
+ if _extra:
+ for name, value in _extra.iteritems():
+ ns, name = _getNsTag(name)
+ attributes.append((ns, name, _utf8(value)))
+ reversed_nsmap = {}
+ if nsmap:
+ for prefix, ns in nsmap.items():
+ if prefix is not None:
+ prefix = _utf8(prefix)
+ _prefixValidOrRaise(prefix)
+ reversed_nsmap[_utf8(ns)] = prefix
+ ns, name = _getNsTag(tag)
+
+ c_method = self._method if method is None else _findOutputMethod(method)
+
+ return _FileWriterElement(self, (ns, name, attributes, reversed_nsmap), c_method)
+
+ cdef _write_qname(self, bytes name, bytes prefix):
+ if prefix: # empty bytes for no prefix (not None to allow sorting)
+ tree.xmlOutputBufferWrite(self._c_out, len(prefix), _cstr(prefix))
+ tree.xmlOutputBufferWrite(self._c_out, 1, ':')
+ tree.xmlOutputBufferWrite(self._c_out, len(name), _cstr(name))
+
+ cdef _write_start_element(self, element_config):
+ if self._status > WRITER_IN_ELEMENT:
+ raise LxmlSyntaxError("cannot append trailing element to complete XML document")
+ ns, name, attributes, nsmap = element_config
+ flat_namespace_map, new_namespaces = self._collect_namespaces(nsmap)
+ prefix = self._find_prefix(ns, flat_namespace_map, new_namespaces)
+ tree.xmlOutputBufferWrite(self._c_out, 1, '<')
+ self._write_qname(name, prefix)
+
+ self._write_attributes_and_namespaces(
+ attributes, flat_namespace_map, new_namespaces)
+
+ tree.xmlOutputBufferWrite(self._c_out, 1, '>')
+ if not self._buffered:
+ tree.xmlOutputBufferFlush(self._c_out)
+ self._handle_error(self._c_out.error)
+
+ self._element_stack.append((ns, name, prefix, flat_namespace_map))
+ self._status = WRITER_IN_ELEMENT
+
+ cdef _write_attributes_and_namespaces(self, list attributes,
+ dict flat_namespace_map,
+ list new_namespaces):
+ if attributes:
+ # _find_prefix() may append to new_namespaces => build them first
+ attributes = [
+ (self._find_prefix(ns, flat_namespace_map, new_namespaces), name, value)
+ for ns, name, value in attributes ]
+ if new_namespaces:
+ new_namespaces.sort()
+ self._write_attributes_list(new_namespaces)
+ if attributes:
+ self._write_attributes_list(attributes)
+
+ cdef _write_attributes_list(self, list attributes):
+ for prefix, name, value in attributes:
+ tree.xmlOutputBufferWrite(self._c_out, 1, ' ')
+ self._write_qname(name, prefix)
+ tree.xmlOutputBufferWrite(self._c_out, 2, '="')
+ _write_attr_string(self._c_out, _cstr(value))
+
+ tree.xmlOutputBufferWrite(self._c_out, 1, '"')
+
+ cdef _write_end_element(self, element_config):
+ if self._status != WRITER_IN_ELEMENT:
+ raise LxmlSyntaxError("not in an element")
+ if not self._element_stack or self._element_stack[-1][:2] != element_config[:2]:
+ raise LxmlSyntaxError("inconsistent exit action in context manager")
+
+ # If previous write operations failed, the context manager exit might still call us.
+ # That is ok, but we stop writing closing tags and handling errors in that case.
+ # For all non-I/O errors, we continue writing closing tags if we can.
+ ok_to_write = self._c_out.error == xmlerror.XML_ERR_OK
+
+ name, prefix = self._element_stack.pop()[1:3]
+ if ok_to_write:
+ tree.xmlOutputBufferWrite(self._c_out, 2, '')
+ self._write_qname(name, prefix)
+ tree.xmlOutputBufferWrite(self._c_out, 1, '>')
+
+ if not self._element_stack:
+ self._status = WRITER_FINISHED
+ if ok_to_write:
+ if not self._buffered:
+ tree.xmlOutputBufferFlush(self._c_out)
+ self._handle_error(self._c_out.error)
+
+ cdef _find_prefix(self, bytes href, dict flat_namespaces_map, list new_namespaces):
+ if href is None:
+ return None
+ if href in flat_namespaces_map:
+ return flat_namespaces_map[href]
+ # need to create a new prefix
+ prefixes = flat_namespaces_map.values()
+ i = 0
+ while True:
+ prefix = _utf8('ns%d' % i)
+ if prefix not in prefixes:
+ new_namespaces.append((b'xmlns', prefix, href))
+ flat_namespaces_map[href] = prefix
+ return prefix
+ i += 1
+
+ cdef _collect_namespaces(self, dict nsmap):
+ new_namespaces = []
+ flat_namespaces_map = {}
+ for ns, prefix in nsmap.iteritems():
+ flat_namespaces_map[ns] = prefix
+ if prefix is None:
+ # use empty bytes rather than None to allow sorting
+ new_namespaces.append((b'', b'xmlns', ns))
+ else:
+ new_namespaces.append((b'xmlns', prefix, ns))
+ # merge in flat namespace map of parent
+ if self._element_stack:
+ for ns, prefix in (self._element_stack[-1][-1]).iteritems():
+ if flat_namespaces_map.get(ns) is None:
+ # unknown or empty prefix => prefer a 'real' prefix
+ flat_namespaces_map[ns] = prefix
+ return flat_namespaces_map, new_namespaces
+
+ def write(self, *args, bint with_tail=True, bint pretty_print=False, method=None):
+ """write(self, *args, with_tail=True, pretty_print=False, method=None)
+
+ Write subtrees or strings into the file.
+
+ If method is not None, it should be one of ('html', 'xml', 'text')
+ to temporarily override the output method.
+ """
+ assert self._c_out is not NULL
+ c_method = self._method if method is None else _findOutputMethod(method)
+
+ for content in args:
+ if _isString(content):
+ if self._status != WRITER_IN_ELEMENT:
+ if self._status > WRITER_IN_ELEMENT or content.strip():
+ raise LxmlSyntaxError("not in an element")
+ bstring = _utf8(content)
+ if not bstring:
+ continue
+
+ ns, name, _, _ = self._element_stack[-1]
+ if (c_method == OUTPUT_METHOD_HTML and
+ ns in (None, b'http://www.w3.org/1999/xhtml') and
+ name in (b'script', b'style')):
+ tree.xmlOutputBufferWrite(self._c_out, len(bstring), _cstr(bstring))
+
+ else:
+ tree.xmlOutputBufferWriteEscape(self._c_out, _xcstr(bstring), NULL)
+
+ elif isinstance(content, CDATA):
+ if self._status > WRITER_IN_ELEMENT:
+ raise LxmlSyntaxError("not in an element")
+ _write_cdata_string(self._c_out, (content)._utf8_data)
+
+ elif iselement(content):
+ if self._status > WRITER_IN_ELEMENT:
+ raise LxmlSyntaxError("cannot append trailing element to complete XML document")
+ _writeNodeToBuffer(self._c_out, (<_Element>content)._c_node,
+ self._c_encoding, NULL, c_method,
+ False, False, pretty_print, with_tail, False)
+ if (<_Element>content)._c_node.type == tree.XML_ELEMENT_NODE:
+ if not self._element_stack:
+ self._status = WRITER_FINISHED
+
+ elif content is not None:
+ raise TypeError(
+ f"got invalid input value of type {type(content)}, expected string, CDATA or Element")
+
+ self._handle_error(self._c_out.error)
+
+ if not self._buffered:
+ tree.xmlOutputBufferFlush(self._c_out)
+ self._handle_error(self._c_out.error)
+
+ def flush(self):
+ """flush(self)
+
+ Write any pending content of the current output buffer to the stream.
+ """
+ assert self._c_out is not NULL
+ tree.xmlOutputBufferFlush(self._c_out)
+ self._handle_error(self._c_out.error)
+
+ cdef _close(self, bint raise_on_error):
+ if raise_on_error:
+ if self._status < WRITER_IN_ELEMENT:
+ raise LxmlSyntaxError("no content written")
+ if self._element_stack:
+ raise LxmlSyntaxError("pending open tags on close")
+ error_result = self._c_out.error
+ if error_result == xmlerror.XML_ERR_OK:
+ error_result = tree.xmlOutputBufferClose(self._c_out)
+ if error_result != -1:
+ error_result = xmlerror.XML_ERR_OK
+ else:
+ tree.xmlOutputBufferClose(self._c_out)
+ self._status = WRITER_FINISHED
+ self._c_out = NULL
+ del self._element_stack[:]
+ if raise_on_error:
+ self._handle_error(error_result)
+
+ cdef _handle_error(self, int error_result):
+ if error_result != xmlerror.XML_ERR_OK:
+ if self._target is not None:
+ self._target._exc_context._raise_if_stored()
+ _raiseSerialisationError(error_result)
+
+
+@cython.final
+@cython.internal
+cdef class _AsyncDataWriter:
+ cdef list _data
+ def __cinit__(self):
+ self._data = []
+
+ cdef bytes collect(self):
+ data = b''.join(self._data)
+ del self._data[:]
+ return data
+
+ def write(self, data):
+ self._data.append(data)
+
+ def close(self):
+ pass
+
+
+@cython.final
+@cython.internal
+cdef class _AsyncIncrementalFileWriter:
+ cdef _IncrementalFileWriter _writer
+ cdef _AsyncDataWriter _buffer
+ cdef object _async_outfile
+ cdef int _flush_after_writes
+ cdef bint _should_close
+ cdef bint _buffered
+
+ def __cinit__(self, async_outfile, bytes encoding, int compresslevel, bint close,
+ bint buffered, int method):
+ self._flush_after_writes = 20
+ self._async_outfile = async_outfile
+ self._should_close = close
+ self._buffered = buffered
+ self._buffer = _AsyncDataWriter()
+ self._writer = _IncrementalFileWriter(
+ self._buffer, encoding, compresslevel, close=True, buffered=False, method=method)
+
+ cdef bytes _flush(self):
+ if not self._buffered or len(self._buffer._data) > self._flush_after_writes:
+ return self._buffer.collect()
+ return None
+
+ async def flush(self):
+ self._writer.flush()
+ data = self._buffer.collect()
+ if data:
+ await self._async_outfile.write(data)
+
+ async def write_declaration(self, version=None, standalone=None, doctype=None):
+ self._writer.write_declaration(version, standalone, doctype)
+ data = self._flush()
+ if data:
+ await self._async_outfile.write(data)
+
+ async def write_doctype(self, doctype):
+ self._writer.write_doctype(doctype)
+ data = self._flush()
+ if data:
+ await self._async_outfile.write(data)
+
+ async def write(self, *args, with_tail=True, pretty_print=False, method=None):
+ self._writer.write(*args, with_tail=with_tail, pretty_print=pretty_print, method=method)
+ data = self._flush()
+ if data:
+ await self._async_outfile.write(data)
+
+ def method(self, method):
+ return self._writer.method(method)
+
+ def element(self, tag, attrib=None, nsmap=None, method=None, **_extra):
+ element_writer = self._writer.element(tag, attrib, nsmap, method, **_extra)
+ return _AsyncFileWriterElement(element_writer, self)
+
+ async def _close(self, bint raise_on_error):
+ self._writer._close(raise_on_error)
+ data = self._buffer.collect()
+ if data:
+ await self._async_outfile.write(data)
+ if self._should_close:
+ await self._async_outfile.close()
+
+
+@cython.final
+@cython.internal
+cdef class _AsyncFileWriterElement:
+ cdef _FileWriterElement _element_writer
+ cdef _AsyncIncrementalFileWriter _writer
+
+ def __cinit__(self, _FileWriterElement element_writer not None,
+ _AsyncIncrementalFileWriter writer not None):
+ self._element_writer = element_writer
+ self._writer = writer
+
+ async def __aenter__(self):
+ self._element_writer.__enter__()
+ data = self._writer._flush()
+ if data:
+ await self._writer._async_outfile.write(data)
+
+ async def __aexit__(self, *args):
+ self._element_writer.__exit__(*args)
+ data = self._writer._flush()
+ if data:
+ await self._writer._async_outfile.write(data)
+
+
+@cython.final
+@cython.internal
+@cython.freelist(8)
+cdef class _FileWriterElement:
+ cdef _IncrementalFileWriter _writer
+ cdef object _element
+ cdef int _new_method
+ cdef int _old_method
+
+ def __cinit__(self, _IncrementalFileWriter writer not None, element_config, int method):
+ self._writer = writer
+ self._element = element_config
+ self._new_method = method
+ self._old_method = writer._method
+
+ def __enter__(self):
+ self._writer._method = self._new_method
+ self._writer._write_start_element(self._element)
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self._writer._write_end_element(self._element)
+ self._writer._method = self._old_method
+
+
+@cython.final
+@cython.internal
+@cython.freelist(8)
+cdef class _MethodChanger:
+ cdef _IncrementalFileWriter _writer
+ cdef int _new_method
+ cdef int _old_method
+ cdef bint _entered
+ cdef bint _exited
+
+ def __cinit__(self, _IncrementalFileWriter writer not None, int method):
+ self._writer = writer
+ self._new_method = method
+ self._old_method = writer._method
+ self._entered = False
+ self._exited = False
+
+ def __enter__(self):
+ if self._entered:
+ raise LxmlSyntaxError("Inconsistent enter action in context manager")
+ self._writer._method = self._new_method
+ self._entered = True
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ if self._exited:
+ raise LxmlSyntaxError("Inconsistent exit action in context manager")
+ if self._writer._method != self._new_method:
+ raise LxmlSyntaxError("Method changed outside of context manager")
+ self._writer._method = self._old_method
+ self._exited = True
+
+ async def __aenter__(self):
+ # for your async convenience
+ return self.__enter__()
+
+ async def __aexit__(self, *args):
+ # for your async convenience
+ return self.__exit__(*args)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/xinclude.pxi b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/xinclude.pxi
new file mode 100644
index 0000000000000000000000000000000000000000..5c9ac45096efb2250e268dd2eed9ade07c2ca998
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/xinclude.pxi
@@ -0,0 +1,67 @@
+# XInclude processing
+
+from lxml.includes cimport xinclude
+
+
+cdef class XIncludeError(LxmlError):
+ """Error during XInclude processing.
+ """
+
+
+cdef class XInclude:
+ """XInclude(self)
+ XInclude processor.
+
+ Create an instance and call it on an Element to run XInclude
+ processing.
+ """
+ cdef _ErrorLog _error_log
+ def __init__(self):
+ self._error_log = _ErrorLog()
+
+ @property
+ def error_log(self):
+ assert self._error_log is not None, "XInclude instance not initialised"
+ return self._error_log.copy()
+
+ def __call__(self, _Element node not None):
+ "__call__(self, node)"
+ # We cannot pass the XML_PARSE_NOXINCNODE option as this would free
+ # the XInclude nodes - there may still be Python references to them!
+ # Therefore, we allow XInclude nodes to be converted to
+ # XML_XINCLUDE_START nodes. XML_XINCLUDE_END nodes are added as
+ # siblings. Tree traversal will simply ignore them as they are not
+ # typed as elements. The included fragment is added between the two,
+ # i.e. as a sibling, which does not conflict with traversal.
+ cdef int result
+ _assertValidNode(node)
+ assert self._error_log is not None, "XInclude processor not initialised"
+ if node._doc._parser is not None:
+ parse_options = node._doc._parser._parse_options
+ context = node._doc._parser._getParserContext()
+ c_context = context
+ else:
+ parse_options = 0
+ context = None
+ c_context = NULL
+
+ self._error_log.connect()
+ if tree.LIBXML_VERSION < 20704 or not c_context:
+ __GLOBAL_PARSER_CONTEXT.pushImpliedContext(context)
+ with nogil:
+ orig_loader = _register_document_loader()
+ if c_context:
+ result = xinclude.xmlXIncludeProcessTreeFlagsData(
+ node._c_node, parse_options, c_context)
+ else:
+ result = xinclude.xmlXIncludeProcessTree(node._c_node)
+ _reset_document_loader(orig_loader)
+ if tree.LIBXML_VERSION < 20704 or not c_context:
+ __GLOBAL_PARSER_CONTEXT.popImpliedContext()
+ self._error_log.disconnect()
+
+ if result == -1:
+ raise XIncludeError(
+ self._error_log._buildExceptionMessage(
+ "XInclude processing failed"),
+ self._error_log)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/xmlerror.pxi b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/xmlerror.pxi
new file mode 100644
index 0000000000000000000000000000000000000000..3be24d21230b3c2c9a7d34bfc31ef969147877d6
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/xmlerror.pxi
@@ -0,0 +1,1662 @@
+# DEBUG and error logging
+
+from lxml.includes cimport xmlerror
+from lxml cimport cvarargs
+
+DEF GLOBAL_ERROR_LOG = "_GlobalErrorLog"
+DEF XSLT_ERROR_LOG = "_XSLTErrorLog"
+
+# module level API functions
+
+def clear_error_log():
+ """clear_error_log()
+
+ Clear the global error log. Note that this log is already bound to a
+ fixed size.
+
+ Note: since lxml 2.2, the global error log is local to a thread
+ and this function will only clear the global error log of the
+ current thread.
+ """
+ _getThreadErrorLog(GLOBAL_ERROR_LOG).clear()
+
+
+# setup for global log:
+
+cdef void _initThreadLogging() noexcept:
+ # Disable generic error lines from libxml2.
+ _connectGenericErrorLog(None)
+
+ # Divert XSLT error messages to the global XSLT error log instead of stderr.
+ xslt.xsltSetGenericErrorFunc(NULL, _receiveXSLTError)
+
+
+# Logging classes
+
+@cython.final
+@cython.freelist(16)
+cdef class _LogEntry:
+ """A log message entry from an error log.
+
+ Attributes:
+
+ - message: the message text
+ - domain: the domain ID (see lxml.etree.ErrorDomains)
+ - type: the message type ID (see lxml.etree.ErrorTypes)
+ - level: the log level ID (see lxml.etree.ErrorLevels)
+ - line: the line at which the message originated (if applicable)
+ - column: the character column at which the message originated (if applicable)
+ - filename: the name of the file in which the message originated (if applicable)
+ - path: the location in which the error was found (if available)
+ """
+ cdef readonly int domain
+ cdef readonly int type
+ cdef readonly int level
+ cdef readonly long line
+ cdef readonly int column
+ cdef basestring _message
+ cdef basestring _filename
+ cdef char* _c_message
+ cdef xmlChar* _c_filename
+ cdef xmlChar* _c_path
+
+ def __dealloc__(self):
+ tree.xmlFree(self._c_message)
+ tree.xmlFree(self._c_filename)
+ tree.xmlFree(self._c_path)
+
+ @cython.final
+ cdef int _setError(self, const xmlerror.xmlError* error) except -1:
+ self.domain = error.domain
+ self.type = error.code
+ self.level = error.level
+ self.line = error.line
+ self.column = error.int2
+ self._c_message = NULL
+ self._c_filename = NULL
+ self._c_path = NULL
+ if (error.message is NULL or
+ error.message[0] == b'\0' or
+ error.message[0] == b'\n' and error.message[1] == b'\0'):
+ self._message = "unknown error"
+ else:
+ self._message = None
+ self._c_message = tree.xmlStrdup(
+ error.message)
+ if not self._c_message:
+ raise MemoryError()
+ if error.file is NULL:
+ self._filename = ''
+ else:
+ self._filename = None
+ self._c_filename = tree.xmlStrdup( error.file)
+ if not self._c_filename:
+ raise MemoryError()
+ if error.node is not NULL:
+ self._c_path = tree.xmlGetNodePath( error.node)
+ c_line = tree.xmlGetLineNo( error.node)
+ if c_line > limits.INT_MAX:
+ self.line = c_line
+
+ @cython.final
+ cdef _setGeneric(self, int domain, int type, int level, long line,
+ message, filename):
+ self.domain = domain
+ self.type = type
+ self.level = level
+ self.line = line
+ self.column = 0
+ self._message = message
+ self._filename = filename
+ self._c_path = NULL
+
+ def __repr__(self):
+ return "%s:%d:%d:%s:%s:%s: %s" % (
+ self.filename, self.line, self.column, self.level_name,
+ self.domain_name, self.type_name, self.message)
+
+ @property
+ def domain_name(self):
+ """The name of the error domain. See lxml.etree.ErrorDomains
+ """
+ return ErrorDomains._getName(self.domain, "unknown")
+
+ @property
+ def type_name(self):
+ """The name of the error type. See lxml.etree.ErrorTypes
+ """
+ if self.domain == ErrorDomains.RELAXNGV:
+ getName = RelaxNGErrorTypes._getName
+ else:
+ getName = ErrorTypes._getName
+ return getName(self.type, "unknown")
+
+ @property
+ def level_name(self):
+ """The name of the error level. See lxml.etree.ErrorLevels
+ """
+ return ErrorLevels._getName(self.level, "unknown")
+
+ @property
+ def message(self):
+ """The log message string.
+ """
+ cdef size_t size
+ if self._message is not None:
+ return self._message
+ if self._c_message is NULL:
+ return None
+ size = cstring_h.strlen(self._c_message)
+ if size > 0 and self._c_message[size-1] == b'\n':
+ size -= 1 # strip EOL
+ # cannot use funicode() here because the message may contain
+ # byte encoded file paths etc.
+ try:
+ self._message = self._c_message[:size].decode('utf8')
+ except UnicodeDecodeError:
+ try:
+ self._message = self._c_message[:size].decode(
+ 'ascii', 'backslashreplace')
+ except UnicodeDecodeError:
+ self._message = ''
+ if self._c_message:
+ # clean up early
+ tree.xmlFree(self._c_message)
+ self._c_message = NULL
+ return self._message
+
+ @property
+ def filename(self):
+ """The file path where the report originated, if any.
+ """
+ if self._filename is None:
+ if self._c_filename is not NULL:
+ self._filename = _decodeFilename(self._c_filename)
+ # clean up early
+ tree.xmlFree(self._c_filename)
+ self._c_filename = NULL
+ return self._filename
+
+ @property
+ def path(self):
+ """The XPath for the node where the error was detected.
+ """
+ return funicode(self._c_path) if self._c_path is not NULL else None
+
+
+cdef class _BaseErrorLog:
+ cdef _LogEntry _first_error
+ cdef readonly object last_error
+ def __init__(self, first_error, last_error):
+ self._first_error = first_error
+ self.last_error = last_error
+
+ cpdef copy(self):
+ return _BaseErrorLog(self._first_error, self.last_error)
+
+ def __repr__(self):
+ return ''
+
+ cpdef receive(self, _LogEntry entry):
+ pass
+
+ @cython.final
+ cdef int _receive(self, const xmlerror.xmlError* error) except -1:
+ cdef bint is_error
+ cdef _LogEntry entry
+ cdef _BaseErrorLog global_log
+ entry = _LogEntry.__new__(_LogEntry)
+ entry._setError(error)
+ is_error = error.level == xmlerror.XML_ERR_ERROR or \
+ error.level == xmlerror.XML_ERR_FATAL
+ global_log = _getThreadErrorLog(GLOBAL_ERROR_LOG)
+ if global_log is not self:
+ global_log.receive(entry)
+ if is_error:
+ global_log.last_error = entry
+ self.receive(entry)
+ if is_error:
+ self.last_error = entry
+
+ @cython.final
+ cdef int _receiveGeneric(self, int domain, int type, int level, long line,
+ message, filename) except -1:
+ cdef bint is_error
+ cdef _LogEntry entry
+ cdef _BaseErrorLog global_log
+ entry = _LogEntry.__new__(_LogEntry)
+ entry._setGeneric(domain, type, level, line, message, filename)
+ is_error = level == xmlerror.XML_ERR_ERROR or \
+ level == xmlerror.XML_ERR_FATAL
+ global_log = _getThreadErrorLog(GLOBAL_ERROR_LOG)
+ if global_log is not self:
+ global_log.receive(entry)
+ if is_error:
+ global_log.last_error = entry
+ self.receive(entry)
+ if is_error:
+ self.last_error = entry
+
+ @cython.final
+ cdef _buildParseException(self, exctype, default_message):
+ code = xmlerror.XML_ERR_INTERNAL_ERROR
+ if self._first_error is None:
+ return exctype(default_message, code, 0, 0)
+ message = self._first_error.message
+ if message:
+ code = self._first_error.type
+ else:
+ message = default_message
+ line = self._first_error.line
+ column = self._first_error.column
+ filename = self._first_error.filename
+ if line > 0:
+ if column > 0:
+ message = f"{message}, line {line}, column {column}"
+ else:
+ message = f"{message}, line {line}"
+ return exctype(message, code, line, column, filename)
+
+ @cython.final
+ cdef _buildExceptionMessage(self, default_message):
+ if self._first_error is None:
+ return default_message
+ if self._first_error.message:
+ message = self._first_error.message
+ elif default_message is None:
+ return None
+ else:
+ message = default_message
+ if self._first_error.line > 0:
+ if self._first_error.column > 0:
+ message = f"{message}, line {self._first_error.line}, column {self._first_error.column}"
+ else:
+ message = f"{message}, line {self._first_error.line}"
+ return message
+
+cdef class _ListErrorLog(_BaseErrorLog):
+ "Immutable base version of a list based error log."
+ cdef list _entries
+ cdef int _offset
+ def __init__(self, entries, first_error, last_error):
+ if entries:
+ if first_error is None:
+ first_error = entries[0]
+ if last_error is None:
+ last_error = entries[-1]
+ _BaseErrorLog.__init__(self, first_error, last_error)
+ self._entries = entries
+
+ cpdef copy(self):
+ """Creates a shallow copy of this error log. Reuses the list of
+ entries.
+ """
+ cdef _ListErrorLog log = _ListErrorLog(
+ self._entries, self._first_error, self.last_error)
+ log._offset = self._offset
+ return log
+
+ def __iter__(self):
+ entries = self._entries
+ if self._offset:
+ entries = islice(entries, self._offset)
+ return iter(entries)
+
+ def __repr__(self):
+ return '\n'.join([repr(entry) for entry in self])
+
+ def __getitem__(self, index):
+ if self._offset:
+ index += self._offset
+ return self._entries[index]
+
+ def __len__(self):
+ return len(self._entries) - self._offset
+
+ def __contains__(self, error_type):
+ cdef Py_ssize_t i
+ for i, entry in enumerate(self._entries):
+ if i < self._offset:
+ continue
+ if entry.type == error_type:
+ return True
+ return False
+
+ def __bool__(self):
+ return len(self._entries) > self._offset
+
+ def filter_domains(self, domains):
+ """Filter the errors by the given domains and return a new error log
+ containing the matches.
+ """
+ cdef _LogEntry entry
+ if isinstance(domains, int):
+ domains = (domains,)
+ filtered = [entry for entry in self if entry.domain in domains]
+ return _ListErrorLog(filtered, None, None)
+
+ def filter_types(self, types):
+ """filter_types(self, types)
+
+ Filter the errors by the given types and return a new error
+ log containing the matches.
+ """
+ cdef _LogEntry entry
+ if isinstance(types, int):
+ types = (types,)
+ filtered = [entry for entry in self if entry.type in types]
+ return _ListErrorLog(filtered, None, None)
+
+ def filter_levels(self, levels):
+ """filter_levels(self, levels)
+
+ Filter the errors by the given error levels and return a new
+ error log containing the matches.
+ """
+ cdef _LogEntry entry
+ if isinstance(levels, int):
+ levels = (levels,)
+ filtered = [entry for entry in self if entry.level in levels]
+ return _ListErrorLog(filtered, None, None)
+
+ def filter_from_level(self, level):
+ """filter_from_level(self, level)
+
+ Return a log with all messages of the requested level of worse.
+ """
+ cdef _LogEntry entry
+ filtered = [entry for entry in self if entry.level >= level]
+ return _ListErrorLog(filtered, None, None)
+
+ def filter_from_fatals(self):
+ """filter_from_fatals(self)
+
+ Convenience method to get all fatal error messages.
+ """
+ return self.filter_from_level(ErrorLevels.FATAL)
+
+ def filter_from_errors(self):
+ """filter_from_errors(self)
+
+ Convenience method to get all error messages or worse.
+ """
+ return self.filter_from_level(ErrorLevels.ERROR)
+
+ def filter_from_warnings(self):
+ """filter_from_warnings(self)
+
+ Convenience method to get all warnings or worse.
+ """
+ return self.filter_from_level(ErrorLevels.WARNING)
+
+
+@cython.final
+@cython.internal
+cdef class _ErrorLogContext:
+ """
+ Error log context for the 'with' statement.
+ Stores a reference to the current callbacks to allow for
+ recursively stacked log contexts.
+ """
+ cdef xmlerror.xmlStructuredErrorFunc old_error_func
+ cdef void* old_error_context
+ cdef xmlerror.xmlGenericErrorFunc old_xslt_error_func
+ cdef void* old_xslt_error_context
+ cdef _BaseErrorLog old_xslt_error_log
+
+ cdef int push_error_log(self, _BaseErrorLog log) except -1:
+ self.old_error_func = xmlerror.xmlStructuredError
+ self.old_error_context = xmlerror.xmlStructuredErrorContext
+ xmlerror.xmlSetStructuredErrorFunc(
+ log, _receiveError)
+
+ # xslt.xsltSetGenericErrorFunc() is not thread-local => keep error log in TLS
+ self.old_xslt_error_func = xslt.xsltGenericError
+ self.old_xslt_error_context = xslt.xsltGenericErrorContext
+ self.old_xslt_error_log = _getThreadErrorLog(XSLT_ERROR_LOG)
+ _setThreadErrorLog(XSLT_ERROR_LOG, log)
+ xslt.xsltSetGenericErrorFunc(
+ NULL, _receiveXSLTError)
+ return 0
+
+ cdef int pop_error_log(self) except -1:
+ xmlerror.xmlSetStructuredErrorFunc(
+ self.old_error_context, self.old_error_func)
+ xslt.xsltSetGenericErrorFunc(
+ self.old_xslt_error_context, self.old_xslt_error_func)
+ _setThreadErrorLog(XSLT_ERROR_LOG, self.old_xslt_error_log)
+ self.old_xslt_error_log= None
+ return 0
+
+
+cdef class _ErrorLog(_ListErrorLog):
+ cdef list _logContexts
+ def __cinit__(self):
+ self._logContexts = []
+
+ def __init__(self):
+ _ListErrorLog.__init__(self, [], None, None)
+
+ @cython.final
+ cdef int __enter__(self) except -1:
+ return self.connect()
+
+ def __exit__(self, *args):
+ # TODO: make this a cdef function when Cython supports it
+ self.disconnect()
+
+ @cython.final
+ cdef int connect(self) except -1:
+ self._first_error = None
+ del self._entries[:]
+
+ cdef _ErrorLogContext context = _ErrorLogContext.__new__(_ErrorLogContext)
+ context.push_error_log(self)
+ self._logContexts.append(context)
+ return 0
+
+ @cython.final
+ cdef int disconnect(self) except -1:
+ cdef _ErrorLogContext context = self._logContexts.pop()
+ context.pop_error_log()
+ return 0
+
+ cpdef clear(self):
+ self._first_error = None
+ self.last_error = None
+ self._offset = 0
+ del self._entries[:]
+
+ cpdef copy(self):
+ """Creates a shallow copy of this error log and the list of entries.
+ """
+ return _ListErrorLog(
+ self._entries[self._offset:],
+ self._first_error, self.last_error)
+
+ def __iter__(self):
+ return iter(self._entries[self._offset:])
+
+ cpdef receive(self, _LogEntry entry):
+ if self._first_error is None and entry.level >= xmlerror.XML_ERR_ERROR:
+ self._first_error = entry
+ self._entries.append(entry)
+
+cdef class _DomainErrorLog(_ErrorLog):
+ def __init__(self, domains):
+ _ErrorLog.__init__(self)
+ self._accepted_domains = tuple(domains)
+
+ cpdef receive(self, _LogEntry entry):
+ if entry.domain in self._accepted_domains:
+ _ErrorLog.receive(self, entry)
+
+cdef class _RotatingErrorLog(_ErrorLog):
+ cdef int _max_len
+ def __init__(self, max_len):
+ _ErrorLog.__init__(self)
+ self._max_len = max_len
+
+ cpdef receive(self, _LogEntry entry):
+ if self._first_error is None and entry.level >= xmlerror.XML_ERR_ERROR:
+ self._first_error = entry
+ self._entries.append(entry)
+
+ if len(self._entries) > self._max_len:
+ self._offset += 1
+ if self._offset > self._max_len // 3:
+ offset = self._offset
+ self._offset = 0
+ del self._entries[:offset]
+
+cdef class PyErrorLog(_BaseErrorLog):
+ """PyErrorLog(self, logger_name=None, logger=None)
+ A global error log that connects to the Python stdlib logging package.
+
+ The constructor accepts an optional logger name or a readily
+ instantiated logger instance.
+
+ If you want to change the mapping between libxml2's ErrorLevels and Python
+ logging levels, you can modify the level_map dictionary from a subclass.
+
+ The default mapping is::
+
+ ErrorLevels.WARNING = logging.WARNING
+ ErrorLevels.ERROR = logging.ERROR
+ ErrorLevels.FATAL = logging.CRITICAL
+
+ You can also override the method ``receive()`` that takes a LogEntry
+ object and calls ``self.log(log_entry, format_string, arg1, arg2, ...)``
+ with appropriate data.
+ """
+ cdef readonly dict level_map
+ cdef object _map_level
+ cdef object _log
+ def __init__(self, logger_name=None, logger=None):
+ _BaseErrorLog.__init__(self, None, None)
+ import logging
+ self.level_map = {
+ ErrorLevels.WARNING : logging.WARNING,
+ ErrorLevels.ERROR : logging.ERROR,
+ ErrorLevels.FATAL : logging.CRITICAL
+ }
+ self._map_level = self.level_map.get
+ if logger is None:
+ if logger_name:
+ logger = logging.getLogger(logger_name)
+ else:
+ logger = logging.getLogger()
+ self._log = logger.log
+
+ cpdef copy(self):
+ """Dummy method that returns an empty error log.
+ """
+ return _ListErrorLog([], None, None)
+
+ def log(self, log_entry, message, *args):
+ """log(self, log_entry, message, *args)
+
+ Called by the .receive() method to log a _LogEntry instance to
+ the Python logging system. This handles the error level
+ mapping.
+
+ In the default implementation, the ``message`` argument
+ receives a complete log line, and there are no further
+ ``args``. To change the message format, it is best to
+ override the .receive() method instead of this one.
+ """
+ self._log(
+ self._map_level(log_entry.level, 0),
+ message, *args
+ )
+
+ cpdef receive(self, _LogEntry log_entry):
+ """receive(self, log_entry)
+
+ Receive a _LogEntry instance from the logging system. Calls
+ the .log() method with appropriate parameters::
+
+ self.log(log_entry, repr(log_entry))
+
+ You can override this method to provide your own log output
+ format.
+ """
+ self.log(log_entry, repr(log_entry))
+
+# thread-local, global list log to collect error output messages from
+# libxml2/libxslt
+
+cdef _BaseErrorLog __GLOBAL_ERROR_LOG = _RotatingErrorLog(__MAX_LOG_SIZE)
+
+
+cdef _BaseErrorLog _getThreadErrorLog(name):
+ """Retrieve the current error log with name 'name' of this thread."""
+ cdef python.PyObject* thread_dict
+ thread_dict = python.PyThreadState_GetDict()
+ if thread_dict is NULL:
+ return __GLOBAL_ERROR_LOG
+ try:
+ return (thread_dict)[name]
+ except KeyError:
+ log = (thread_dict)[name] = \
+ _RotatingErrorLog(__MAX_LOG_SIZE)
+ return log
+
+
+cdef _setThreadErrorLog(name, _BaseErrorLog log):
+ """Set the global error log of this thread."""
+ cdef python.PyObject* thread_dict
+ thread_dict = python.PyThreadState_GetDict()
+ if thread_dict is NULL:
+ if name == GLOBAL_ERROR_LOG:
+ global __GLOBAL_ERROR_LOG
+ __GLOBAL_ERROR_LOG = log
+ else:
+ (thread_dict)[name] = log
+
+
+cdef __copyGlobalErrorLog():
+ "Helper function for properties in exceptions."
+ return _getThreadErrorLog(GLOBAL_ERROR_LOG).copy()
+
+
+def use_global_python_log(PyErrorLog log not None):
+ """use_global_python_log(log)
+
+ Replace the global error log by an etree.PyErrorLog that uses the
+ standard Python logging package.
+
+ Note that this disables access to the global error log from exceptions.
+ Parsers, XSLT etc. will continue to provide their normal local error log.
+
+ Note: prior to lxml 2.2, this changed the error log globally.
+ Since lxml 2.2, the global error log is local to a thread and this
+ function will only set the global error log of the current thread.
+ """
+ _setThreadErrorLog(GLOBAL_ERROR_LOG, log)
+
+
+# local log functions: forward error to logger object
+cdef void _forwardError(void* c_log_handler, const xmlerror.xmlError* error) noexcept with gil:
+ cdef _BaseErrorLog log_handler
+ if c_log_handler is not NULL:
+ log_handler = <_BaseErrorLog>c_log_handler
+ elif error.domain == xmlerror.XML_FROM_XSLT:
+ log_handler = _getThreadErrorLog(XSLT_ERROR_LOG)
+ else:
+ log_handler = _getThreadErrorLog(GLOBAL_ERROR_LOG)
+ log_handler._receive(error)
+
+
+cdef void _receiveError(void* c_log_handler, const xmlerror.xmlError* error) noexcept nogil:
+ # no Python objects here, may be called without thread context !
+ if __DEBUG:
+ _forwardError(c_log_handler, error)
+
+
+cdef void _receiveXSLTError(void* c_log_handler, char* msg, ...) noexcept nogil:
+ # no Python objects here, may be called without thread context !
+ cdef cvarargs.va_list args
+ cvarargs.va_start(args, msg)
+ _receiveGenericError(c_log_handler, xmlerror.XML_FROM_XSLT, msg, args)
+ cvarargs.va_end(args)
+
+cdef void _receiveRelaxNGParseError(void* c_log_handler, char* msg, ...) noexcept nogil:
+ # no Python objects here, may be called without thread context !
+ cdef cvarargs.va_list args
+ cvarargs.va_start(args, msg)
+ _receiveGenericError(c_log_handler, xmlerror.XML_FROM_RELAXNGP, msg, args)
+ cvarargs.va_end(args)
+
+cdef void _receiveRelaxNGValidationError(void* c_log_handler, char* msg, ...) noexcept nogil:
+ # no Python objects here, may be called without thread context !
+ cdef cvarargs.va_list args
+ cvarargs.va_start(args, msg)
+ _receiveGenericError(c_log_handler, xmlerror.XML_FROM_RELAXNGV, msg, args)
+ cvarargs.va_end(args)
+
+# dummy function: no log output at all
+cdef void _nullGenericErrorFunc(void* ctxt, char* msg, ...) noexcept nogil:
+ pass
+
+
+cdef void _connectGenericErrorLog(log, int c_domain=-1) noexcept:
+ cdef xmlerror.xmlGenericErrorFunc error_func = NULL
+ c_log = log
+ if c_domain == xmlerror.XML_FROM_XSLT:
+ error_func = _receiveXSLTError
+ elif c_domain == xmlerror.XML_FROM_RELAXNGP:
+ error_func = _receiveRelaxNGParseError
+ elif c_domain == xmlerror.XML_FROM_RELAXNGV:
+ error_func = _receiveRelaxNGValidationError
+
+ if log is None or error_func is NULL:
+ c_log = NULL
+ error_func = _nullGenericErrorFunc
+ xmlerror.xmlSetGenericErrorFunc(c_log, error_func)
+
+
+cdef void _receiveGenericError(void* c_log_handler, int c_domain,
+ char* msg, cvarargs.va_list args) noexcept nogil:
+ # no Python objects here, may be called without thread context !
+ cdef xmlerror.xmlError c_error
+ cdef char* c_text
+ cdef char* c_message
+ cdef char* c_element
+ cdef char* c_pos
+ cdef char* c_name_pos
+ cdef char* c_str
+ cdef int text_size, element_size, format_count, c_int
+ if not __DEBUG or msg is NULL:
+ return
+ if msg[0] in b'\n\0':
+ return
+
+ c_text = c_element = c_error.file = c_error.node = NULL
+ c_error.line = 0
+
+ # parse "NAME %s" chunks from the format string
+ c_name_pos = c_pos = msg
+ format_count = 0
+ while c_pos[0]:
+ if c_pos[0] == b'%':
+ c_pos += 1
+ if c_pos[0] == b's': # "%s"
+ format_count += 1
+ c_str = cvarargs.va_charptr(args)
+ if c_pos == msg + 1:
+ c_text = c_str # msg == "%s..."
+ elif c_name_pos[0] == b'e':
+ if cstring_h.strncmp(c_name_pos, 'element %s', 10) == 0:
+ c_element = c_str
+ elif c_name_pos[0] == b'f':
+ if cstring_h.strncmp(c_name_pos, 'file %s', 7) == 0:
+ if cstring_h.strncmp('string://__STRING__XSLT',
+ c_str, 23) == 0:
+ c_str = ''
+ c_error.file = c_str
+ elif c_pos[0] == b'd': # "%d"
+ format_count += 1
+ c_int = cvarargs.va_int(args)
+ if cstring_h.strncmp(c_name_pos, 'line %d', 7) == 0:
+ c_error.line = c_int
+ elif c_pos[0] != b'%': # "%%" == "%"
+ format_count += 1
+ break # unexpected format or end of string => abort
+ elif c_pos[0] == b' ':
+ if c_pos[1] != b'%':
+ c_name_pos = c_pos + 1
+ c_pos += 1
+
+ c_message = NULL
+ if c_text is NULL:
+ if c_element is not NULL and format_count == 1:
+ # special case: a single occurrence of 'element %s'
+ text_size = cstring_h.strlen(msg)
+ element_size = cstring_h.strlen(c_element)
+ c_message = stdlib.malloc(
+ (text_size + element_size + 1) * sizeof(char))
+ stdio.sprintf(c_message, msg, c_element)
+ c_error.message = c_message
+ else:
+ c_error.message = ''
+ elif c_element is NULL:
+ c_error.message = c_text
+ else:
+ text_size = cstring_h.strlen(c_text)
+ element_size = cstring_h.strlen(c_element)
+ c_message = stdlib.malloc(
+ (text_size + 12 + element_size + 1) * sizeof(char))
+ if c_message is NULL:
+ c_error.message = c_text
+ else:
+ stdio.sprintf(c_message, "%s, element '%s'", c_text, c_element)
+ c_error.message = c_message
+
+ c_error.domain = c_domain
+ c_error.code = xmlerror.XML_ERR_OK # what else?
+ c_error.level = xmlerror.XML_ERR_ERROR # what else?
+ c_error.int2 = 0
+
+ _forwardError(c_log_handler, &c_error)
+
+ if c_message is not NULL:
+ stdlib.free(c_message)
+
+################################################################################
+## CONSTANTS FROM "xmlerror.h" (or rather libxml-xmlerror.html)
+################################################################################
+
+cdef __initErrorConstants():
+ "Called at setup time to parse the constants and build the classes below."
+ global __ERROR_LEVELS, __ERROR_DOMAINS, __PARSER_ERROR_TYPES, __RELAXNG_ERROR_TYPES
+ const_defs = ((ErrorLevels, __ERROR_LEVELS),
+ (ErrorDomains, __ERROR_DOMAINS),
+ (ErrorTypes, __PARSER_ERROR_TYPES),
+ (RelaxNGErrorTypes, __RELAXNG_ERROR_TYPES))
+
+ for cls, constants in const_defs:
+ reverse_dict = {}
+ cls._names = reverse_dict
+ cls._getName = reverse_dict.get
+ for line in constants.splitlines():
+ if not line:
+ continue
+ name, value = line.split('=')
+ value = int(value)
+ setattr(cls, name, value)
+ reverse_dict[value] = name
+
+ # discard the global string references after use
+ __ERROR_LEVELS = __ERROR_DOMAINS = __PARSER_ERROR_TYPES = __RELAXNG_ERROR_TYPES = None
+
+
+class ErrorLevels(object):
+ """Libxml2 error levels"""
+
+class ErrorDomains(object):
+ """Libxml2 error domains"""
+
+class ErrorTypes(object):
+ """Libxml2 error types"""
+
+class RelaxNGErrorTypes(object):
+ """Libxml2 RelaxNG error types"""
+
+
+# --- BEGIN: GENERATED CONSTANTS ---
+
+# This section is generated by the script 'update-error-constants.py'.
+
+cdef object __ERROR_LEVELS = """\
+NONE=0
+WARNING=1
+ERROR=2
+FATAL=3
+"""
+
+cdef object __ERROR_DOMAINS = """\
+NONE=0
+PARSER=1
+TREE=2
+NAMESPACE=3
+DTD=4
+HTML=5
+MEMORY=6
+OUTPUT=7
+IO=8
+FTP=9
+HTTP=10
+XINCLUDE=11
+XPATH=12
+XPOINTER=13
+REGEXP=14
+DATATYPE=15
+SCHEMASP=16
+SCHEMASV=17
+RELAXNGP=18
+RELAXNGV=19
+CATALOG=20
+C14N=21
+XSLT=22
+VALID=23
+CHECK=24
+WRITER=25
+MODULE=26
+I18N=27
+SCHEMATRONV=28
+BUFFER=29
+URI=30
+"""
+
+cdef object __PARSER_ERROR_TYPES = """\
+ERR_OK=0
+ERR_INTERNAL_ERROR=1
+ERR_NO_MEMORY=2
+ERR_DOCUMENT_START=3
+ERR_DOCUMENT_EMPTY=4
+ERR_DOCUMENT_END=5
+ERR_INVALID_HEX_CHARREF=6
+ERR_INVALID_DEC_CHARREF=7
+ERR_INVALID_CHARREF=8
+ERR_INVALID_CHAR=9
+ERR_CHARREF_AT_EOF=10
+ERR_CHARREF_IN_PROLOG=11
+ERR_CHARREF_IN_EPILOG=12
+ERR_CHARREF_IN_DTD=13
+ERR_ENTITYREF_AT_EOF=14
+ERR_ENTITYREF_IN_PROLOG=15
+ERR_ENTITYREF_IN_EPILOG=16
+ERR_ENTITYREF_IN_DTD=17
+ERR_PEREF_AT_EOF=18
+ERR_PEREF_IN_PROLOG=19
+ERR_PEREF_IN_EPILOG=20
+ERR_PEREF_IN_INT_SUBSET=21
+ERR_ENTITYREF_NO_NAME=22
+ERR_ENTITYREF_SEMICOL_MISSING=23
+ERR_PEREF_NO_NAME=24
+ERR_PEREF_SEMICOL_MISSING=25
+ERR_UNDECLARED_ENTITY=26
+WAR_UNDECLARED_ENTITY=27
+ERR_UNPARSED_ENTITY=28
+ERR_ENTITY_IS_EXTERNAL=29
+ERR_ENTITY_IS_PARAMETER=30
+ERR_UNKNOWN_ENCODING=31
+ERR_UNSUPPORTED_ENCODING=32
+ERR_STRING_NOT_STARTED=33
+ERR_STRING_NOT_CLOSED=34
+ERR_NS_DECL_ERROR=35
+ERR_ENTITY_NOT_STARTED=36
+ERR_ENTITY_NOT_FINISHED=37
+ERR_LT_IN_ATTRIBUTE=38
+ERR_ATTRIBUTE_NOT_STARTED=39
+ERR_ATTRIBUTE_NOT_FINISHED=40
+ERR_ATTRIBUTE_WITHOUT_VALUE=41
+ERR_ATTRIBUTE_REDEFINED=42
+ERR_LITERAL_NOT_STARTED=43
+ERR_LITERAL_NOT_FINISHED=44
+ERR_COMMENT_NOT_FINISHED=45
+ERR_PI_NOT_STARTED=46
+ERR_PI_NOT_FINISHED=47
+ERR_NOTATION_NOT_STARTED=48
+ERR_NOTATION_NOT_FINISHED=49
+ERR_ATTLIST_NOT_STARTED=50
+ERR_ATTLIST_NOT_FINISHED=51
+ERR_MIXED_NOT_STARTED=52
+ERR_MIXED_NOT_FINISHED=53
+ERR_ELEMCONTENT_NOT_STARTED=54
+ERR_ELEMCONTENT_NOT_FINISHED=55
+ERR_XMLDECL_NOT_STARTED=56
+ERR_XMLDECL_NOT_FINISHED=57
+ERR_CONDSEC_NOT_STARTED=58
+ERR_CONDSEC_NOT_FINISHED=59
+ERR_EXT_SUBSET_NOT_FINISHED=60
+ERR_DOCTYPE_NOT_FINISHED=61
+ERR_MISPLACED_CDATA_END=62
+ERR_CDATA_NOT_FINISHED=63
+ERR_RESERVED_XML_NAME=64
+ERR_SPACE_REQUIRED=65
+ERR_SEPARATOR_REQUIRED=66
+ERR_NMTOKEN_REQUIRED=67
+ERR_NAME_REQUIRED=68
+ERR_PCDATA_REQUIRED=69
+ERR_URI_REQUIRED=70
+ERR_PUBID_REQUIRED=71
+ERR_LT_REQUIRED=72
+ERR_GT_REQUIRED=73
+ERR_LTSLASH_REQUIRED=74
+ERR_EQUAL_REQUIRED=75
+ERR_TAG_NAME_MISMATCH=76
+ERR_TAG_NOT_FINISHED=77
+ERR_STANDALONE_VALUE=78
+ERR_ENCODING_NAME=79
+ERR_HYPHEN_IN_COMMENT=80
+ERR_INVALID_ENCODING=81
+ERR_EXT_ENTITY_STANDALONE=82
+ERR_CONDSEC_INVALID=83
+ERR_VALUE_REQUIRED=84
+ERR_NOT_WELL_BALANCED=85
+ERR_EXTRA_CONTENT=86
+ERR_ENTITY_CHAR_ERROR=87
+ERR_ENTITY_PE_INTERNAL=88
+ERR_ENTITY_LOOP=89
+ERR_ENTITY_BOUNDARY=90
+ERR_INVALID_URI=91
+ERR_URI_FRAGMENT=92
+WAR_CATALOG_PI=93
+ERR_NO_DTD=94
+ERR_CONDSEC_INVALID_KEYWORD=95
+ERR_VERSION_MISSING=96
+WAR_UNKNOWN_VERSION=97
+WAR_LANG_VALUE=98
+WAR_NS_URI=99
+WAR_NS_URI_RELATIVE=100
+ERR_MISSING_ENCODING=101
+WAR_SPACE_VALUE=102
+ERR_NOT_STANDALONE=103
+ERR_ENTITY_PROCESSING=104
+ERR_NOTATION_PROCESSING=105
+WAR_NS_COLUMN=106
+WAR_ENTITY_REDEFINED=107
+ERR_UNKNOWN_VERSION=108
+ERR_VERSION_MISMATCH=109
+ERR_NAME_TOO_LONG=110
+ERR_USER_STOP=111
+ERR_COMMENT_ABRUPTLY_ENDED=112
+WAR_ENCODING_MISMATCH=113
+ERR_RESOURCE_LIMIT=114
+ERR_ARGUMENT=115
+ERR_SYSTEM=116
+ERR_REDECL_PREDEF_ENTITY=117
+ERR_INT_SUBSET_NOT_FINISHED=118
+NS_ERR_XML_NAMESPACE=200
+NS_ERR_UNDEFINED_NAMESPACE=201
+NS_ERR_QNAME=202
+NS_ERR_ATTRIBUTE_REDEFINED=203
+NS_ERR_EMPTY=204
+NS_ERR_COLON=205
+DTD_ATTRIBUTE_DEFAULT=500
+DTD_ATTRIBUTE_REDEFINED=501
+DTD_ATTRIBUTE_VALUE=502
+DTD_CONTENT_ERROR=503
+DTD_CONTENT_MODEL=504
+DTD_CONTENT_NOT_DETERMINIST=505
+DTD_DIFFERENT_PREFIX=506
+DTD_ELEM_DEFAULT_NAMESPACE=507
+DTD_ELEM_NAMESPACE=508
+DTD_ELEM_REDEFINED=509
+DTD_EMPTY_NOTATION=510
+DTD_ENTITY_TYPE=511
+DTD_ID_FIXED=512
+DTD_ID_REDEFINED=513
+DTD_ID_SUBSET=514
+DTD_INVALID_CHILD=515
+DTD_INVALID_DEFAULT=516
+DTD_LOAD_ERROR=517
+DTD_MISSING_ATTRIBUTE=518
+DTD_MIXED_CORRUPT=519
+DTD_MULTIPLE_ID=520
+DTD_NO_DOC=521
+DTD_NO_DTD=522
+DTD_NO_ELEM_NAME=523
+DTD_NO_PREFIX=524
+DTD_NO_ROOT=525
+DTD_NOTATION_REDEFINED=526
+DTD_NOTATION_VALUE=527
+DTD_NOT_EMPTY=528
+DTD_NOT_PCDATA=529
+DTD_NOT_STANDALONE=530
+DTD_ROOT_NAME=531
+DTD_STANDALONE_WHITE_SPACE=532
+DTD_UNKNOWN_ATTRIBUTE=533
+DTD_UNKNOWN_ELEM=534
+DTD_UNKNOWN_ENTITY=535
+DTD_UNKNOWN_ID=536
+DTD_UNKNOWN_NOTATION=537
+DTD_STANDALONE_DEFAULTED=538
+DTD_XMLID_VALUE=539
+DTD_XMLID_TYPE=540
+DTD_DUP_TOKEN=541
+HTML_STRUCURE_ERROR=800
+HTML_UNKNOWN_TAG=801
+HTML_INCORRECTLY_OPENED_COMMENT=802
+RNGP_ANYNAME_ATTR_ANCESTOR=1000
+RNGP_ATTR_CONFLICT=1001
+RNGP_ATTRIBUTE_CHILDREN=1002
+RNGP_ATTRIBUTE_CONTENT=1003
+RNGP_ATTRIBUTE_EMPTY=1004
+RNGP_ATTRIBUTE_NOOP=1005
+RNGP_CHOICE_CONTENT=1006
+RNGP_CHOICE_EMPTY=1007
+RNGP_CREATE_FAILURE=1008
+RNGP_DATA_CONTENT=1009
+RNGP_DEF_CHOICE_AND_INTERLEAVE=1010
+RNGP_DEFINE_CREATE_FAILED=1011
+RNGP_DEFINE_EMPTY=1012
+RNGP_DEFINE_MISSING=1013
+RNGP_DEFINE_NAME_MISSING=1014
+RNGP_ELEM_CONTENT_EMPTY=1015
+RNGP_ELEM_CONTENT_ERROR=1016
+RNGP_ELEMENT_EMPTY=1017
+RNGP_ELEMENT_CONTENT=1018
+RNGP_ELEMENT_NAME=1019
+RNGP_ELEMENT_NO_CONTENT=1020
+RNGP_ELEM_TEXT_CONFLICT=1021
+RNGP_EMPTY=1022
+RNGP_EMPTY_CONSTRUCT=1023
+RNGP_EMPTY_CONTENT=1024
+RNGP_EMPTY_NOT_EMPTY=1025
+RNGP_ERROR_TYPE_LIB=1026
+RNGP_EXCEPT_EMPTY=1027
+RNGP_EXCEPT_MISSING=1028
+RNGP_EXCEPT_MULTIPLE=1029
+RNGP_EXCEPT_NO_CONTENT=1030
+RNGP_EXTERNALREF_EMTPY=1031
+RNGP_EXTERNAL_REF_FAILURE=1032
+RNGP_EXTERNALREF_RECURSE=1033
+RNGP_FORBIDDEN_ATTRIBUTE=1034
+RNGP_FOREIGN_ELEMENT=1035
+RNGP_GRAMMAR_CONTENT=1036
+RNGP_GRAMMAR_EMPTY=1037
+RNGP_GRAMMAR_MISSING=1038
+RNGP_GRAMMAR_NO_START=1039
+RNGP_GROUP_ATTR_CONFLICT=1040
+RNGP_HREF_ERROR=1041
+RNGP_INCLUDE_EMPTY=1042
+RNGP_INCLUDE_FAILURE=1043
+RNGP_INCLUDE_RECURSE=1044
+RNGP_INTERLEAVE_ADD=1045
+RNGP_INTERLEAVE_CREATE_FAILED=1046
+RNGP_INTERLEAVE_EMPTY=1047
+RNGP_INTERLEAVE_NO_CONTENT=1048
+RNGP_INVALID_DEFINE_NAME=1049
+RNGP_INVALID_URI=1050
+RNGP_INVALID_VALUE=1051
+RNGP_MISSING_HREF=1052
+RNGP_NAME_MISSING=1053
+RNGP_NEED_COMBINE=1054
+RNGP_NOTALLOWED_NOT_EMPTY=1055
+RNGP_NSNAME_ATTR_ANCESTOR=1056
+RNGP_NSNAME_NO_NS=1057
+RNGP_PARAM_FORBIDDEN=1058
+RNGP_PARAM_NAME_MISSING=1059
+RNGP_PARENTREF_CREATE_FAILED=1060
+RNGP_PARENTREF_NAME_INVALID=1061
+RNGP_PARENTREF_NO_NAME=1062
+RNGP_PARENTREF_NO_PARENT=1063
+RNGP_PARENTREF_NOT_EMPTY=1064
+RNGP_PARSE_ERROR=1065
+RNGP_PAT_ANYNAME_EXCEPT_ANYNAME=1066
+RNGP_PAT_ATTR_ATTR=1067
+RNGP_PAT_ATTR_ELEM=1068
+RNGP_PAT_DATA_EXCEPT_ATTR=1069
+RNGP_PAT_DATA_EXCEPT_ELEM=1070
+RNGP_PAT_DATA_EXCEPT_EMPTY=1071
+RNGP_PAT_DATA_EXCEPT_GROUP=1072
+RNGP_PAT_DATA_EXCEPT_INTERLEAVE=1073
+RNGP_PAT_DATA_EXCEPT_LIST=1074
+RNGP_PAT_DATA_EXCEPT_ONEMORE=1075
+RNGP_PAT_DATA_EXCEPT_REF=1076
+RNGP_PAT_DATA_EXCEPT_TEXT=1077
+RNGP_PAT_LIST_ATTR=1078
+RNGP_PAT_LIST_ELEM=1079
+RNGP_PAT_LIST_INTERLEAVE=1080
+RNGP_PAT_LIST_LIST=1081
+RNGP_PAT_LIST_REF=1082
+RNGP_PAT_LIST_TEXT=1083
+RNGP_PAT_NSNAME_EXCEPT_ANYNAME=1084
+RNGP_PAT_NSNAME_EXCEPT_NSNAME=1085
+RNGP_PAT_ONEMORE_GROUP_ATTR=1086
+RNGP_PAT_ONEMORE_INTERLEAVE_ATTR=1087
+RNGP_PAT_START_ATTR=1088
+RNGP_PAT_START_DATA=1089
+RNGP_PAT_START_EMPTY=1090
+RNGP_PAT_START_GROUP=1091
+RNGP_PAT_START_INTERLEAVE=1092
+RNGP_PAT_START_LIST=1093
+RNGP_PAT_START_ONEMORE=1094
+RNGP_PAT_START_TEXT=1095
+RNGP_PAT_START_VALUE=1096
+RNGP_PREFIX_UNDEFINED=1097
+RNGP_REF_CREATE_FAILED=1098
+RNGP_REF_CYCLE=1099
+RNGP_REF_NAME_INVALID=1100
+RNGP_REF_NO_DEF=1101
+RNGP_REF_NO_NAME=1102
+RNGP_REF_NOT_EMPTY=1103
+RNGP_START_CHOICE_AND_INTERLEAVE=1104
+RNGP_START_CONTENT=1105
+RNGP_START_EMPTY=1106
+RNGP_START_MISSING=1107
+RNGP_TEXT_EXPECTED=1108
+RNGP_TEXT_HAS_CHILD=1109
+RNGP_TYPE_MISSING=1110
+RNGP_TYPE_NOT_FOUND=1111
+RNGP_TYPE_VALUE=1112
+RNGP_UNKNOWN_ATTRIBUTE=1113
+RNGP_UNKNOWN_COMBINE=1114
+RNGP_UNKNOWN_CONSTRUCT=1115
+RNGP_UNKNOWN_TYPE_LIB=1116
+RNGP_URI_FRAGMENT=1117
+RNGP_URI_NOT_ABSOLUTE=1118
+RNGP_VALUE_EMPTY=1119
+RNGP_VALUE_NO_CONTENT=1120
+RNGP_XMLNS_NAME=1121
+RNGP_XML_NS=1122
+XPATH_EXPRESSION_OK=1200
+XPATH_NUMBER_ERROR=1201
+XPATH_UNFINISHED_LITERAL_ERROR=1202
+XPATH_START_LITERAL_ERROR=1203
+XPATH_VARIABLE_REF_ERROR=1204
+XPATH_UNDEF_VARIABLE_ERROR=1205
+XPATH_INVALID_PREDICATE_ERROR=1206
+XPATH_EXPR_ERROR=1207
+XPATH_UNCLOSED_ERROR=1208
+XPATH_UNKNOWN_FUNC_ERROR=1209
+XPATH_INVALID_OPERAND=1210
+XPATH_INVALID_TYPE=1211
+XPATH_INVALID_ARITY=1212
+XPATH_INVALID_CTXT_SIZE=1213
+XPATH_INVALID_CTXT_POSITION=1214
+XPATH_MEMORY_ERROR=1215
+XPTR_SYNTAX_ERROR=1216
+XPTR_RESOURCE_ERROR=1217
+XPTR_SUB_RESOURCE_ERROR=1218
+XPATH_UNDEF_PREFIX_ERROR=1219
+XPATH_ENCODING_ERROR=1220
+XPATH_INVALID_CHAR_ERROR=1221
+TREE_INVALID_HEX=1300
+TREE_INVALID_DEC=1301
+TREE_UNTERMINATED_ENTITY=1302
+TREE_NOT_UTF8=1303
+SAVE_NOT_UTF8=1400
+SAVE_CHAR_INVALID=1401
+SAVE_NO_DOCTYPE=1402
+SAVE_UNKNOWN_ENCODING=1403
+REGEXP_COMPILE_ERROR=1450
+IO_UNKNOWN=1500
+IO_EACCES=1501
+IO_EAGAIN=1502
+IO_EBADF=1503
+IO_EBADMSG=1504
+IO_EBUSY=1505
+IO_ECANCELED=1506
+IO_ECHILD=1507
+IO_EDEADLK=1508
+IO_EDOM=1509
+IO_EEXIST=1510
+IO_EFAULT=1511
+IO_EFBIG=1512
+IO_EINPROGRESS=1513
+IO_EINTR=1514
+IO_EINVAL=1515
+IO_EIO=1516
+IO_EISDIR=1517
+IO_EMFILE=1518
+IO_EMLINK=1519
+IO_EMSGSIZE=1520
+IO_ENAMETOOLONG=1521
+IO_ENFILE=1522
+IO_ENODEV=1523
+IO_ENOENT=1524
+IO_ENOEXEC=1525
+IO_ENOLCK=1526
+IO_ENOMEM=1527
+IO_ENOSPC=1528
+IO_ENOSYS=1529
+IO_ENOTDIR=1530
+IO_ENOTEMPTY=1531
+IO_ENOTSUP=1532
+IO_ENOTTY=1533
+IO_ENXIO=1534
+IO_EPERM=1535
+IO_EPIPE=1536
+IO_ERANGE=1537
+IO_EROFS=1538
+IO_ESPIPE=1539
+IO_ESRCH=1540
+IO_ETIMEDOUT=1541
+IO_EXDEV=1542
+IO_NETWORK_ATTEMPT=1543
+IO_ENCODER=1544
+IO_FLUSH=1545
+IO_WRITE=1546
+IO_NO_INPUT=1547
+IO_BUFFER_FULL=1548
+IO_LOAD_ERROR=1549
+IO_ENOTSOCK=1550
+IO_EISCONN=1551
+IO_ECONNREFUSED=1552
+IO_ENETUNREACH=1553
+IO_EADDRINUSE=1554
+IO_EALREADY=1555
+IO_EAFNOSUPPORT=1556
+IO_UNSUPPORTED_PROTOCOL=1557
+XINCLUDE_RECURSION=1600
+XINCLUDE_PARSE_VALUE=1601
+XINCLUDE_ENTITY_DEF_MISMATCH=1602
+XINCLUDE_NO_HREF=1603
+XINCLUDE_NO_FALLBACK=1604
+XINCLUDE_HREF_URI=1605
+XINCLUDE_TEXT_FRAGMENT=1606
+XINCLUDE_TEXT_DOCUMENT=1607
+XINCLUDE_INVALID_CHAR=1608
+XINCLUDE_BUILD_FAILED=1609
+XINCLUDE_UNKNOWN_ENCODING=1610
+XINCLUDE_MULTIPLE_ROOT=1611
+XINCLUDE_XPTR_FAILED=1612
+XINCLUDE_XPTR_RESULT=1613
+XINCLUDE_INCLUDE_IN_INCLUDE=1614
+XINCLUDE_FALLBACKS_IN_INCLUDE=1615
+XINCLUDE_FALLBACK_NOT_IN_INCLUDE=1616
+XINCLUDE_DEPRECATED_NS=1617
+XINCLUDE_FRAGMENT_ID=1618
+CATALOG_MISSING_ATTR=1650
+CATALOG_ENTRY_BROKEN=1651
+CATALOG_PREFER_VALUE=1652
+CATALOG_NOT_CATALOG=1653
+CATALOG_RECURSION=1654
+SCHEMAP_PREFIX_UNDEFINED=1700
+SCHEMAP_ATTRFORMDEFAULT_VALUE=1701
+SCHEMAP_ATTRGRP_NONAME_NOREF=1702
+SCHEMAP_ATTR_NONAME_NOREF=1703
+SCHEMAP_COMPLEXTYPE_NONAME_NOREF=1704
+SCHEMAP_ELEMFORMDEFAULT_VALUE=1705
+SCHEMAP_ELEM_NONAME_NOREF=1706
+SCHEMAP_EXTENSION_NO_BASE=1707
+SCHEMAP_FACET_NO_VALUE=1708
+SCHEMAP_FAILED_BUILD_IMPORT=1709
+SCHEMAP_GROUP_NONAME_NOREF=1710
+SCHEMAP_IMPORT_NAMESPACE_NOT_URI=1711
+SCHEMAP_IMPORT_REDEFINE_NSNAME=1712
+SCHEMAP_IMPORT_SCHEMA_NOT_URI=1713
+SCHEMAP_INVALID_BOOLEAN=1714
+SCHEMAP_INVALID_ENUM=1715
+SCHEMAP_INVALID_FACET=1716
+SCHEMAP_INVALID_FACET_VALUE=1717
+SCHEMAP_INVALID_MAXOCCURS=1718
+SCHEMAP_INVALID_MINOCCURS=1719
+SCHEMAP_INVALID_REF_AND_SUBTYPE=1720
+SCHEMAP_INVALID_WHITE_SPACE=1721
+SCHEMAP_NOATTR_NOREF=1722
+SCHEMAP_NOTATION_NO_NAME=1723
+SCHEMAP_NOTYPE_NOREF=1724
+SCHEMAP_REF_AND_SUBTYPE=1725
+SCHEMAP_RESTRICTION_NONAME_NOREF=1726
+SCHEMAP_SIMPLETYPE_NONAME=1727
+SCHEMAP_TYPE_AND_SUBTYPE=1728
+SCHEMAP_UNKNOWN_ALL_CHILD=1729
+SCHEMAP_UNKNOWN_ANYATTRIBUTE_CHILD=1730
+SCHEMAP_UNKNOWN_ATTR_CHILD=1731
+SCHEMAP_UNKNOWN_ATTRGRP_CHILD=1732
+SCHEMAP_UNKNOWN_ATTRIBUTE_GROUP=1733
+SCHEMAP_UNKNOWN_BASE_TYPE=1734
+SCHEMAP_UNKNOWN_CHOICE_CHILD=1735
+SCHEMAP_UNKNOWN_COMPLEXCONTENT_CHILD=1736
+SCHEMAP_UNKNOWN_COMPLEXTYPE_CHILD=1737
+SCHEMAP_UNKNOWN_ELEM_CHILD=1738
+SCHEMAP_UNKNOWN_EXTENSION_CHILD=1739
+SCHEMAP_UNKNOWN_FACET_CHILD=1740
+SCHEMAP_UNKNOWN_FACET_TYPE=1741
+SCHEMAP_UNKNOWN_GROUP_CHILD=1742
+SCHEMAP_UNKNOWN_IMPORT_CHILD=1743
+SCHEMAP_UNKNOWN_LIST_CHILD=1744
+SCHEMAP_UNKNOWN_NOTATION_CHILD=1745
+SCHEMAP_UNKNOWN_PROCESSCONTENT_CHILD=1746
+SCHEMAP_UNKNOWN_REF=1747
+SCHEMAP_UNKNOWN_RESTRICTION_CHILD=1748
+SCHEMAP_UNKNOWN_SCHEMAS_CHILD=1749
+SCHEMAP_UNKNOWN_SEQUENCE_CHILD=1750
+SCHEMAP_UNKNOWN_SIMPLECONTENT_CHILD=1751
+SCHEMAP_UNKNOWN_SIMPLETYPE_CHILD=1752
+SCHEMAP_UNKNOWN_TYPE=1753
+SCHEMAP_UNKNOWN_UNION_CHILD=1754
+SCHEMAP_ELEM_DEFAULT_FIXED=1755
+SCHEMAP_REGEXP_INVALID=1756
+SCHEMAP_FAILED_LOAD=1757
+SCHEMAP_NOTHING_TO_PARSE=1758
+SCHEMAP_NOROOT=1759
+SCHEMAP_REDEFINED_GROUP=1760
+SCHEMAP_REDEFINED_TYPE=1761
+SCHEMAP_REDEFINED_ELEMENT=1762
+SCHEMAP_REDEFINED_ATTRGROUP=1763
+SCHEMAP_REDEFINED_ATTR=1764
+SCHEMAP_REDEFINED_NOTATION=1765
+SCHEMAP_FAILED_PARSE=1766
+SCHEMAP_UNKNOWN_PREFIX=1767
+SCHEMAP_DEF_AND_PREFIX=1768
+SCHEMAP_UNKNOWN_INCLUDE_CHILD=1769
+SCHEMAP_INCLUDE_SCHEMA_NOT_URI=1770
+SCHEMAP_INCLUDE_SCHEMA_NO_URI=1771
+SCHEMAP_NOT_SCHEMA=1772
+SCHEMAP_UNKNOWN_MEMBER_TYPE=1773
+SCHEMAP_INVALID_ATTR_USE=1774
+SCHEMAP_RECURSIVE=1775
+SCHEMAP_SUPERNUMEROUS_LIST_ITEM_TYPE=1776
+SCHEMAP_INVALID_ATTR_COMBINATION=1777
+SCHEMAP_INVALID_ATTR_INLINE_COMBINATION=1778
+SCHEMAP_MISSING_SIMPLETYPE_CHILD=1779
+SCHEMAP_INVALID_ATTR_NAME=1780
+SCHEMAP_REF_AND_CONTENT=1781
+SCHEMAP_CT_PROPS_CORRECT_1=1782
+SCHEMAP_CT_PROPS_CORRECT_2=1783
+SCHEMAP_CT_PROPS_CORRECT_3=1784
+SCHEMAP_CT_PROPS_CORRECT_4=1785
+SCHEMAP_CT_PROPS_CORRECT_5=1786
+SCHEMAP_DERIVATION_OK_RESTRICTION_1=1787
+SCHEMAP_DERIVATION_OK_RESTRICTION_2_1_1=1788
+SCHEMAP_DERIVATION_OK_RESTRICTION_2_1_2=1789
+SCHEMAP_DERIVATION_OK_RESTRICTION_2_2=1790
+SCHEMAP_DERIVATION_OK_RESTRICTION_3=1791
+SCHEMAP_WILDCARD_INVALID_NS_MEMBER=1792
+SCHEMAP_INTERSECTION_NOT_EXPRESSIBLE=1793
+SCHEMAP_UNION_NOT_EXPRESSIBLE=1794
+SCHEMAP_SRC_IMPORT_3_1=1795
+SCHEMAP_SRC_IMPORT_3_2=1796
+SCHEMAP_DERIVATION_OK_RESTRICTION_4_1=1797
+SCHEMAP_DERIVATION_OK_RESTRICTION_4_2=1798
+SCHEMAP_DERIVATION_OK_RESTRICTION_4_3=1799
+SCHEMAP_COS_CT_EXTENDS_1_3=1800
+SCHEMAV_NOROOT=1801
+SCHEMAV_UNDECLAREDELEM=1802
+SCHEMAV_NOTTOPLEVEL=1803
+SCHEMAV_MISSING=1804
+SCHEMAV_WRONGELEM=1805
+SCHEMAV_NOTYPE=1806
+SCHEMAV_NOROLLBACK=1807
+SCHEMAV_ISABSTRACT=1808
+SCHEMAV_NOTEMPTY=1809
+SCHEMAV_ELEMCONT=1810
+SCHEMAV_HAVEDEFAULT=1811
+SCHEMAV_NOTNILLABLE=1812
+SCHEMAV_EXTRACONTENT=1813
+SCHEMAV_INVALIDATTR=1814
+SCHEMAV_INVALIDELEM=1815
+SCHEMAV_NOTDETERMINIST=1816
+SCHEMAV_CONSTRUCT=1817
+SCHEMAV_INTERNAL=1818
+SCHEMAV_NOTSIMPLE=1819
+SCHEMAV_ATTRUNKNOWN=1820
+SCHEMAV_ATTRINVALID=1821
+SCHEMAV_VALUE=1822
+SCHEMAV_FACET=1823
+SCHEMAV_CVC_DATATYPE_VALID_1_2_1=1824
+SCHEMAV_CVC_DATATYPE_VALID_1_2_2=1825
+SCHEMAV_CVC_DATATYPE_VALID_1_2_3=1826
+SCHEMAV_CVC_TYPE_3_1_1=1827
+SCHEMAV_CVC_TYPE_3_1_2=1828
+SCHEMAV_CVC_FACET_VALID=1829
+SCHEMAV_CVC_LENGTH_VALID=1830
+SCHEMAV_CVC_MINLENGTH_VALID=1831
+SCHEMAV_CVC_MAXLENGTH_VALID=1832
+SCHEMAV_CVC_MININCLUSIVE_VALID=1833
+SCHEMAV_CVC_MAXINCLUSIVE_VALID=1834
+SCHEMAV_CVC_MINEXCLUSIVE_VALID=1835
+SCHEMAV_CVC_MAXEXCLUSIVE_VALID=1836
+SCHEMAV_CVC_TOTALDIGITS_VALID=1837
+SCHEMAV_CVC_FRACTIONDIGITS_VALID=1838
+SCHEMAV_CVC_PATTERN_VALID=1839
+SCHEMAV_CVC_ENUMERATION_VALID=1840
+SCHEMAV_CVC_COMPLEX_TYPE_2_1=1841
+SCHEMAV_CVC_COMPLEX_TYPE_2_2=1842
+SCHEMAV_CVC_COMPLEX_TYPE_2_3=1843
+SCHEMAV_CVC_COMPLEX_TYPE_2_4=1844
+SCHEMAV_CVC_ELT_1=1845
+SCHEMAV_CVC_ELT_2=1846
+SCHEMAV_CVC_ELT_3_1=1847
+SCHEMAV_CVC_ELT_3_2_1=1848
+SCHEMAV_CVC_ELT_3_2_2=1849
+SCHEMAV_CVC_ELT_4_1=1850
+SCHEMAV_CVC_ELT_4_2=1851
+SCHEMAV_CVC_ELT_4_3=1852
+SCHEMAV_CVC_ELT_5_1_1=1853
+SCHEMAV_CVC_ELT_5_1_2=1854
+SCHEMAV_CVC_ELT_5_2_1=1855
+SCHEMAV_CVC_ELT_5_2_2_1=1856
+SCHEMAV_CVC_ELT_5_2_2_2_1=1857
+SCHEMAV_CVC_ELT_5_2_2_2_2=1858
+SCHEMAV_CVC_ELT_6=1859
+SCHEMAV_CVC_ELT_7=1860
+SCHEMAV_CVC_ATTRIBUTE_1=1861
+SCHEMAV_CVC_ATTRIBUTE_2=1862
+SCHEMAV_CVC_ATTRIBUTE_3=1863
+SCHEMAV_CVC_ATTRIBUTE_4=1864
+SCHEMAV_CVC_COMPLEX_TYPE_3_1=1865
+SCHEMAV_CVC_COMPLEX_TYPE_3_2_1=1866
+SCHEMAV_CVC_COMPLEX_TYPE_3_2_2=1867
+SCHEMAV_CVC_COMPLEX_TYPE_4=1868
+SCHEMAV_CVC_COMPLEX_TYPE_5_1=1869
+SCHEMAV_CVC_COMPLEX_TYPE_5_2=1870
+SCHEMAV_ELEMENT_CONTENT=1871
+SCHEMAV_DOCUMENT_ELEMENT_MISSING=1872
+SCHEMAV_CVC_COMPLEX_TYPE_1=1873
+SCHEMAV_CVC_AU=1874
+SCHEMAV_CVC_TYPE_1=1875
+SCHEMAV_CVC_TYPE_2=1876
+SCHEMAV_CVC_IDC=1877
+SCHEMAV_CVC_WILDCARD=1878
+SCHEMAV_MISC=1879
+XPTR_UNKNOWN_SCHEME=1900
+XPTR_CHILDSEQ_START=1901
+XPTR_EVAL_FAILED=1902
+XPTR_EXTRA_OBJECTS=1903
+C14N_CREATE_CTXT=1950
+C14N_REQUIRES_UTF8=1951
+C14N_CREATE_STACK=1952
+C14N_INVALID_NODE=1953
+C14N_UNKNOW_NODE=1954
+C14N_RELATIVE_NAMESPACE=1955
+FTP_PASV_ANSWER=2000
+FTP_EPSV_ANSWER=2001
+FTP_ACCNT=2002
+FTP_URL_SYNTAX=2003
+HTTP_URL_SYNTAX=2020
+HTTP_USE_IP=2021
+HTTP_UNKNOWN_HOST=2022
+SCHEMAP_SRC_SIMPLE_TYPE_1=3000
+SCHEMAP_SRC_SIMPLE_TYPE_2=3001
+SCHEMAP_SRC_SIMPLE_TYPE_3=3002
+SCHEMAP_SRC_SIMPLE_TYPE_4=3003
+SCHEMAP_SRC_RESOLVE=3004
+SCHEMAP_SRC_RESTRICTION_BASE_OR_SIMPLETYPE=3005
+SCHEMAP_SRC_LIST_ITEMTYPE_OR_SIMPLETYPE=3006
+SCHEMAP_SRC_UNION_MEMBERTYPES_OR_SIMPLETYPES=3007
+SCHEMAP_ST_PROPS_CORRECT_1=3008
+SCHEMAP_ST_PROPS_CORRECT_2=3009
+SCHEMAP_ST_PROPS_CORRECT_3=3010
+SCHEMAP_COS_ST_RESTRICTS_1_1=3011
+SCHEMAP_COS_ST_RESTRICTS_1_2=3012
+SCHEMAP_COS_ST_RESTRICTS_1_3_1=3013
+SCHEMAP_COS_ST_RESTRICTS_1_3_2=3014
+SCHEMAP_COS_ST_RESTRICTS_2_1=3015
+SCHEMAP_COS_ST_RESTRICTS_2_3_1_1=3016
+SCHEMAP_COS_ST_RESTRICTS_2_3_1_2=3017
+SCHEMAP_COS_ST_RESTRICTS_2_3_2_1=3018
+SCHEMAP_COS_ST_RESTRICTS_2_3_2_2=3019
+SCHEMAP_COS_ST_RESTRICTS_2_3_2_3=3020
+SCHEMAP_COS_ST_RESTRICTS_2_3_2_4=3021
+SCHEMAP_COS_ST_RESTRICTS_2_3_2_5=3022
+SCHEMAP_COS_ST_RESTRICTS_3_1=3023
+SCHEMAP_COS_ST_RESTRICTS_3_3_1=3024
+SCHEMAP_COS_ST_RESTRICTS_3_3_1_2=3025
+SCHEMAP_COS_ST_RESTRICTS_3_3_2_2=3026
+SCHEMAP_COS_ST_RESTRICTS_3_3_2_1=3027
+SCHEMAP_COS_ST_RESTRICTS_3_3_2_3=3028
+SCHEMAP_COS_ST_RESTRICTS_3_3_2_4=3029
+SCHEMAP_COS_ST_RESTRICTS_3_3_2_5=3030
+SCHEMAP_COS_ST_DERIVED_OK_2_1=3031
+SCHEMAP_COS_ST_DERIVED_OK_2_2=3032
+SCHEMAP_S4S_ELEM_NOT_ALLOWED=3033
+SCHEMAP_S4S_ELEM_MISSING=3034
+SCHEMAP_S4S_ATTR_NOT_ALLOWED=3035
+SCHEMAP_S4S_ATTR_MISSING=3036
+SCHEMAP_S4S_ATTR_INVALID_VALUE=3037
+SCHEMAP_SRC_ELEMENT_1=3038
+SCHEMAP_SRC_ELEMENT_2_1=3039
+SCHEMAP_SRC_ELEMENT_2_2=3040
+SCHEMAP_SRC_ELEMENT_3=3041
+SCHEMAP_P_PROPS_CORRECT_1=3042
+SCHEMAP_P_PROPS_CORRECT_2_1=3043
+SCHEMAP_P_PROPS_CORRECT_2_2=3044
+SCHEMAP_E_PROPS_CORRECT_2=3045
+SCHEMAP_E_PROPS_CORRECT_3=3046
+SCHEMAP_E_PROPS_CORRECT_4=3047
+SCHEMAP_E_PROPS_CORRECT_5=3048
+SCHEMAP_E_PROPS_CORRECT_6=3049
+SCHEMAP_SRC_INCLUDE=3050
+SCHEMAP_SRC_ATTRIBUTE_1=3051
+SCHEMAP_SRC_ATTRIBUTE_2=3052
+SCHEMAP_SRC_ATTRIBUTE_3_1=3053
+SCHEMAP_SRC_ATTRIBUTE_3_2=3054
+SCHEMAP_SRC_ATTRIBUTE_4=3055
+SCHEMAP_NO_XMLNS=3056
+SCHEMAP_NO_XSI=3057
+SCHEMAP_COS_VALID_DEFAULT_1=3058
+SCHEMAP_COS_VALID_DEFAULT_2_1=3059
+SCHEMAP_COS_VALID_DEFAULT_2_2_1=3060
+SCHEMAP_COS_VALID_DEFAULT_2_2_2=3061
+SCHEMAP_CVC_SIMPLE_TYPE=3062
+SCHEMAP_COS_CT_EXTENDS_1_1=3063
+SCHEMAP_SRC_IMPORT_1_1=3064
+SCHEMAP_SRC_IMPORT_1_2=3065
+SCHEMAP_SRC_IMPORT_2=3066
+SCHEMAP_SRC_IMPORT_2_1=3067
+SCHEMAP_SRC_IMPORT_2_2=3068
+SCHEMAP_INTERNAL=3069
+SCHEMAP_NOT_DETERMINISTIC=3070
+SCHEMAP_SRC_ATTRIBUTE_GROUP_1=3071
+SCHEMAP_SRC_ATTRIBUTE_GROUP_2=3072
+SCHEMAP_SRC_ATTRIBUTE_GROUP_3=3073
+SCHEMAP_MG_PROPS_CORRECT_1=3074
+SCHEMAP_MG_PROPS_CORRECT_2=3075
+SCHEMAP_SRC_CT_1=3076
+SCHEMAP_DERIVATION_OK_RESTRICTION_2_1_3=3077
+SCHEMAP_AU_PROPS_CORRECT_2=3078
+SCHEMAP_A_PROPS_CORRECT_2=3079
+SCHEMAP_C_PROPS_CORRECT=3080
+SCHEMAP_SRC_REDEFINE=3081
+SCHEMAP_SRC_IMPORT=3082
+SCHEMAP_WARN_SKIP_SCHEMA=3083
+SCHEMAP_WARN_UNLOCATED_SCHEMA=3084
+SCHEMAP_WARN_ATTR_REDECL_PROH=3085
+SCHEMAP_WARN_ATTR_POINTLESS_PROH=3086
+SCHEMAP_AG_PROPS_CORRECT=3087
+SCHEMAP_COS_CT_EXTENDS_1_2=3088
+SCHEMAP_AU_PROPS_CORRECT=3089
+SCHEMAP_A_PROPS_CORRECT_3=3090
+SCHEMAP_COS_ALL_LIMITED=3091
+SCHEMATRONV_ASSERT=4000
+SCHEMATRONV_REPORT=4001
+MODULE_OPEN=4900
+MODULE_CLOSE=4901
+CHECK_FOUND_ELEMENT=5000
+CHECK_FOUND_ATTRIBUTE=5001
+CHECK_FOUND_TEXT=5002
+CHECK_FOUND_CDATA=5003
+CHECK_FOUND_ENTITYREF=5004
+CHECK_FOUND_ENTITY=5005
+CHECK_FOUND_PI=5006
+CHECK_FOUND_COMMENT=5007
+CHECK_FOUND_DOCTYPE=5008
+CHECK_FOUND_FRAGMENT=5009
+CHECK_FOUND_NOTATION=5010
+CHECK_UNKNOWN_NODE=5011
+CHECK_ENTITY_TYPE=5012
+CHECK_NO_PARENT=5013
+CHECK_NO_DOC=5014
+CHECK_NO_NAME=5015
+CHECK_NO_ELEM=5016
+CHECK_WRONG_DOC=5017
+CHECK_NO_PREV=5018
+CHECK_WRONG_PREV=5019
+CHECK_NO_NEXT=5020
+CHECK_WRONG_NEXT=5021
+CHECK_NOT_DTD=5022
+CHECK_NOT_ATTR=5023
+CHECK_NOT_ATTR_DECL=5024
+CHECK_NOT_ELEM_DECL=5025
+CHECK_NOT_ENTITY_DECL=5026
+CHECK_NOT_NS_DECL=5027
+CHECK_NO_HREF=5028
+CHECK_WRONG_PARENT=5029
+CHECK_NS_SCOPE=5030
+CHECK_NS_ANCESTOR=5031
+CHECK_NOT_UTF8=5032
+CHECK_NO_DICT=5033
+CHECK_NOT_NCNAME=5034
+CHECK_OUTSIDE_DICT=5035
+CHECK_WRONG_NAME=5036
+CHECK_NAME_NOT_NULL=5037
+I18N_NO_NAME=6000
+I18N_NO_HANDLER=6001
+I18N_EXCESS_HANDLER=6002
+I18N_CONV_FAILED=6003
+I18N_NO_OUTPUT=6004
+BUF_OVERFLOW=7000
+"""
+
+cdef object __RELAXNG_ERROR_TYPES = """\
+RELAXNG_OK=0
+RELAXNG_ERR_MEMORY=1
+RELAXNG_ERR_TYPE=2
+RELAXNG_ERR_TYPEVAL=3
+RELAXNG_ERR_DUPID=4
+RELAXNG_ERR_TYPECMP=5
+RELAXNG_ERR_NOSTATE=6
+RELAXNG_ERR_NODEFINE=7
+RELAXNG_ERR_LISTEXTRA=8
+RELAXNG_ERR_LISTEMPTY=9
+RELAXNG_ERR_INTERNODATA=10
+RELAXNG_ERR_INTERSEQ=11
+RELAXNG_ERR_INTEREXTRA=12
+RELAXNG_ERR_ELEMNAME=13
+RELAXNG_ERR_ATTRNAME=14
+RELAXNG_ERR_ELEMNONS=15
+RELAXNG_ERR_ATTRNONS=16
+RELAXNG_ERR_ELEMWRONGNS=17
+RELAXNG_ERR_ATTRWRONGNS=18
+RELAXNG_ERR_ELEMEXTRANS=19
+RELAXNG_ERR_ATTREXTRANS=20
+RELAXNG_ERR_ELEMNOTEMPTY=21
+RELAXNG_ERR_NOELEM=22
+RELAXNG_ERR_NOTELEM=23
+RELAXNG_ERR_ATTRVALID=24
+RELAXNG_ERR_CONTENTVALID=25
+RELAXNG_ERR_EXTRACONTENT=26
+RELAXNG_ERR_INVALIDATTR=27
+RELAXNG_ERR_DATAELEM=28
+RELAXNG_ERR_VALELEM=29
+RELAXNG_ERR_LISTELEM=30
+RELAXNG_ERR_DATATYPE=31
+RELAXNG_ERR_VALUE=32
+RELAXNG_ERR_LIST=33
+RELAXNG_ERR_NOGRAMMAR=34
+RELAXNG_ERR_EXTRADATA=35
+RELAXNG_ERR_LACKDATA=36
+RELAXNG_ERR_INTERNAL=37
+RELAXNG_ERR_ELEMWRONG=38
+RELAXNG_ERR_TEXTWRONG=39
+"""
+# --- END: GENERATED CONSTANTS ---
+
+__initErrorConstants()
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/xpath.pxi b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/xpath.pxi
new file mode 100644
index 0000000000000000000000000000000000000000..352f63134734780e5a9c869ccb59b4cb4e4ade40
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/xpath.pxi
@@ -0,0 +1,487 @@
+# XPath evaluation
+
+class XPathSyntaxError(LxmlSyntaxError, XPathError):
+ pass
+
+################################################################################
+# XPath
+
+cdef object _XPATH_SYNTAX_ERRORS = (
+ xmlerror.XML_XPATH_NUMBER_ERROR,
+ xmlerror.XML_XPATH_UNFINISHED_LITERAL_ERROR,
+ xmlerror.XML_XPATH_VARIABLE_REF_ERROR,
+ xmlerror.XML_XPATH_INVALID_PREDICATE_ERROR,
+ xmlerror.XML_XPATH_UNCLOSED_ERROR,
+ xmlerror.XML_XPATH_INVALID_CHAR_ERROR
+)
+
+cdef object _XPATH_EVAL_ERRORS = (
+ xmlerror.XML_XPATH_UNDEF_VARIABLE_ERROR,
+ xmlerror.XML_XPATH_UNDEF_PREFIX_ERROR,
+ xmlerror.XML_XPATH_UNKNOWN_FUNC_ERROR,
+ xmlerror.XML_XPATH_INVALID_OPERAND,
+ xmlerror.XML_XPATH_INVALID_TYPE,
+ xmlerror.XML_XPATH_INVALID_ARITY,
+ xmlerror.XML_XPATH_INVALID_CTXT_SIZE,
+ xmlerror.XML_XPATH_INVALID_CTXT_POSITION
+)
+
+cdef int _register_xpath_function(void* ctxt, name_utf, ns_utf) noexcept:
+ if ns_utf is None:
+ return xpath.xmlXPathRegisterFunc(
+ ctxt, _xcstr(name_utf),
+ _xpath_function_call)
+ else:
+ return xpath.xmlXPathRegisterFuncNS(
+ ctxt, _xcstr(name_utf), _xcstr(ns_utf),
+ _xpath_function_call)
+
+cdef int _unregister_xpath_function(void* ctxt, name_utf, ns_utf) noexcept:
+ if ns_utf is None:
+ return xpath.xmlXPathRegisterFunc(
+ ctxt, _xcstr(name_utf), NULL)
+ else:
+ return xpath.xmlXPathRegisterFuncNS(
+ ctxt, _xcstr(name_utf), _xcstr(ns_utf), NULL)
+
+
+@cython.final
+@cython.internal
+cdef class _XPathContext(_BaseContext):
+ cdef object _variables
+ def __init__(self, namespaces, extensions, error_log, enable_regexp, variables,
+ build_smart_strings):
+ self._variables = variables
+ _BaseContext.__init__(self, namespaces, extensions, error_log, enable_regexp,
+ build_smart_strings)
+
+ cdef set_context(self, xpath.xmlXPathContext* xpathCtxt):
+ self._set_xpath_context(xpathCtxt)
+ # This would be a good place to set up the XPath parser dict, but
+ # we cannot use the current thread dict as we do not know which
+ # thread will execute the XPath evaluator - so, no dict for now.
+ self.registerLocalNamespaces()
+ self.registerLocalFunctions(xpathCtxt, _register_xpath_function)
+
+ cdef register_context(self, _Document doc):
+ self._register_context(doc)
+ self.registerGlobalNamespaces()
+ self.registerGlobalFunctions(self._xpathCtxt, _register_xpath_function)
+ self.registerExsltFunctions()
+ if self._variables is not None:
+ self.registerVariables(self._variables)
+
+ cdef unregister_context(self):
+ self.unregisterGlobalFunctions(
+ self._xpathCtxt, _unregister_xpath_function)
+ self.unregisterGlobalNamespaces()
+ xpath.xmlXPathRegisteredVariablesCleanup(self._xpathCtxt)
+ self._cleanup_context()
+
+ cdef void registerExsltFunctions(self) noexcept:
+ if xslt.LIBXSLT_VERSION < 10125:
+ # we'd only execute dummy functions anyway
+ return
+ tree.xmlHashScan(
+ self._xpathCtxt.nsHash, _registerExsltFunctionsForNamespaces,
+ self._xpathCtxt)
+
+ cdef registerVariables(self, variable_dict):
+ for name, value in variable_dict.items():
+ name_utf = self._to_utf(name)
+ xpath.xmlXPathRegisterVariable(
+ self._xpathCtxt, _xcstr(name_utf), _wrapXPathObject(value, None, None))
+
+ cdef registerVariable(self, name, value):
+ name_utf = self._to_utf(name)
+ xpath.xmlXPathRegisterVariable(
+ self._xpathCtxt, _xcstr(name_utf), _wrapXPathObject(value, None, None))
+
+
+cdef void _registerExsltFunctionsForNamespaces(
+ void* _c_href, void* _ctxt, const_xmlChar* c_prefix) noexcept:
+ c_href = _c_href
+ ctxt = _ctxt
+
+ if tree.xmlStrcmp(c_href, xslt.EXSLT_DATE_NAMESPACE) == 0:
+ xslt.exsltDateXpathCtxtRegister(ctxt, c_prefix)
+ elif tree.xmlStrcmp(c_href, xslt.EXSLT_SETS_NAMESPACE) == 0:
+ xslt.exsltSetsXpathCtxtRegister(ctxt, c_prefix)
+ elif tree.xmlStrcmp(c_href, xslt.EXSLT_MATH_NAMESPACE) == 0:
+ xslt.exsltMathXpathCtxtRegister(ctxt, c_prefix)
+ elif tree.xmlStrcmp(c_href, xslt.EXSLT_STRINGS_NAMESPACE) == 0:
+ xslt.exsltStrXpathCtxtRegister(ctxt, c_prefix)
+
+
+cdef class _XPathEvaluatorBase:
+ cdef xpath.xmlXPathContext* _xpathCtxt
+ cdef _XPathContext _context
+ cdef python.PyThread_type_lock _eval_lock
+ cdef _ErrorLog _error_log
+ def __cinit__(self):
+ self._xpathCtxt = NULL
+ if config.ENABLE_THREADING:
+ self._eval_lock = python.PyThread_allocate_lock()
+ if self._eval_lock is NULL:
+ raise MemoryError()
+ self._error_log = _ErrorLog()
+
+ def __init__(self, namespaces, extensions, enable_regexp,
+ smart_strings):
+ self._context = _XPathContext(namespaces, extensions, self._error_log,
+ enable_regexp, None, smart_strings)
+
+ @property
+ def error_log(self):
+ assert self._error_log is not None, "XPath evaluator not initialised"
+ return self._error_log.copy()
+
+ def __dealloc__(self):
+ if self._xpathCtxt is not NULL:
+ xpath.xmlXPathFreeContext(self._xpathCtxt)
+ if config.ENABLE_THREADING:
+ if self._eval_lock is not NULL:
+ python.PyThread_free_lock(self._eval_lock)
+
+ cdef set_context(self, xpath.xmlXPathContext* xpathCtxt):
+ self._xpathCtxt = xpathCtxt
+ self._context.set_context(xpathCtxt)
+
+ cdef bint _checkAbsolutePath(self, char* path) noexcept:
+ cdef char c
+ if path is NULL:
+ return 0
+ c = path[0]
+ while c == c' ' or c == c'\t':
+ path = path + 1
+ c = path[0]
+ return c == c'/'
+
+ @cython.final
+ cdef int _lock(self) except -1:
+ cdef int result
+ if config.ENABLE_THREADING and self._eval_lock != NULL:
+ with nogil:
+ result = python.PyThread_acquire_lock(
+ self._eval_lock, python.WAIT_LOCK)
+ if result == 0:
+ raise XPathError, "XPath evaluator locking failed"
+ return 0
+
+ @cython.final
+ cdef void _unlock(self) noexcept:
+ if config.ENABLE_THREADING and self._eval_lock != NULL:
+ python.PyThread_release_lock(self._eval_lock)
+
+ cdef _build_parse_error(self):
+ cdef _BaseErrorLog entries
+ entries = self._error_log.filter_types(_XPATH_SYNTAX_ERRORS)
+ if entries:
+ message = entries._buildExceptionMessage(None)
+ if message is not None:
+ return XPathSyntaxError(message, self._error_log)
+ return XPathSyntaxError(
+ self._error_log._buildExceptionMessage("Error in xpath expression"),
+ self._error_log)
+
+ cdef _build_eval_error(self):
+ cdef _BaseErrorLog entries
+ entries = self._error_log.filter_types(_XPATH_EVAL_ERRORS)
+ if not entries:
+ entries = self._error_log.filter_types(_XPATH_SYNTAX_ERRORS)
+ if entries:
+ message = entries._buildExceptionMessage(None)
+ if message is not None:
+ return XPathEvalError(message, self._error_log)
+ return XPathEvalError(
+ self._error_log._buildExceptionMessage("Error in xpath expression"),
+ self._error_log)
+
+ cdef object _handle_result(self, xpath.xmlXPathObject* xpathObj, _Document doc):
+ if self._context._exc._has_raised():
+ if xpathObj is not NULL:
+ _freeXPathObject(xpathObj)
+ xpathObj = NULL
+ self._context._release_temp_refs()
+ self._context._exc._raise_if_stored()
+
+ if xpathObj is NULL:
+ self._context._release_temp_refs()
+ raise self._build_eval_error()
+
+ try:
+ result = _unwrapXPathObject(xpathObj, doc, self._context)
+ finally:
+ _freeXPathObject(xpathObj)
+ self._context._release_temp_refs()
+
+ return result
+
+
+cdef class XPathElementEvaluator(_XPathEvaluatorBase):
+ """XPathElementEvaluator(self, element, namespaces=None, extensions=None, regexp=True, smart_strings=True)
+ Create an XPath evaluator for an element.
+
+ Absolute XPath expressions (starting with '/') will be evaluated against
+ the ElementTree as returned by getroottree().
+
+ Additional namespace declarations can be passed with the
+ 'namespace' keyword argument. EXSLT regular expression support
+ can be disabled with the 'regexp' boolean keyword (defaults to
+ True). Smart strings will be returned for string results unless
+ you pass ``smart_strings=False``.
+ """
+ cdef _Element _element
+ def __init__(self, _Element element not None, *, namespaces=None,
+ extensions=None, regexp=True, smart_strings=True):
+ cdef xpath.xmlXPathContext* xpathCtxt
+ cdef int ns_register_status
+ cdef _Document doc
+ _assertValidNode(element)
+ _assertValidDoc(element._doc)
+ self._element = element
+ doc = element._doc
+ _XPathEvaluatorBase.__init__(self, namespaces, extensions,
+ regexp, smart_strings)
+ xpathCtxt = xpath.xmlXPathNewContext(doc._c_doc)
+ if xpathCtxt is NULL:
+ raise MemoryError()
+ self.set_context(xpathCtxt)
+
+ def register_namespace(self, prefix, uri):
+ """Register a namespace with the XPath context.
+ """
+ assert self._xpathCtxt is not NULL, "XPath context not initialised"
+ self._context.addNamespace(prefix, uri)
+
+ def register_namespaces(self, namespaces):
+ """Register a prefix -> uri dict.
+ """
+ assert self._xpathCtxt is not NULL, "XPath context not initialised"
+ for prefix, uri in namespaces.items():
+ self._context.addNamespace(prefix, uri)
+
+ def __call__(self, _path, **_variables):
+ """__call__(self, _path, **_variables)
+
+ Evaluate an XPath expression on the document.
+
+ Variables may be provided as keyword arguments. Note that namespaces
+ are currently not supported for variables.
+
+ Absolute XPath expressions (starting with '/') will be evaluated
+ against the ElementTree as returned by getroottree().
+ """
+ cdef xpath.xmlXPathObject* xpathObj
+ cdef _Document doc
+ assert self._xpathCtxt is not NULL, "XPath context not initialised"
+ path = _utf8(_path)
+ doc = self._element._doc
+
+ self._lock()
+ self._xpathCtxt.node = self._element._c_node
+ try:
+ self._context.register_context(doc)
+ self._context.registerVariables(_variables)
+ c_path = _xcstr(path)
+ with nogil:
+ xpathObj = xpath.xmlXPathEvalExpression(
+ c_path, self._xpathCtxt)
+ result = self._handle_result(xpathObj, doc)
+ finally:
+ self._context.unregister_context()
+ self._unlock()
+
+ return result
+
+
+cdef class XPathDocumentEvaluator(XPathElementEvaluator):
+ """XPathDocumentEvaluator(self, etree, namespaces=None, extensions=None, regexp=True, smart_strings=True)
+ Create an XPath evaluator for an ElementTree.
+
+ Additional namespace declarations can be passed with the
+ 'namespace' keyword argument. EXSLT regular expression support
+ can be disabled with the 'regexp' boolean keyword (defaults to
+ True). Smart strings will be returned for string results unless
+ you pass ``smart_strings=False``.
+ """
+ def __init__(self, _ElementTree etree not None, *, namespaces=None,
+ extensions=None, regexp=True, smart_strings=True):
+ XPathElementEvaluator.__init__(
+ self, etree._context_node, namespaces=namespaces,
+ extensions=extensions, regexp=regexp,
+ smart_strings=smart_strings)
+
+ def __call__(self, _path, **_variables):
+ """__call__(self, _path, **_variables)
+
+ Evaluate an XPath expression on the document.
+
+ Variables may be provided as keyword arguments. Note that namespaces
+ are currently not supported for variables.
+ """
+ cdef xpath.xmlXPathObject* xpathObj
+ cdef xmlDoc* c_doc
+ cdef _Document doc
+ assert self._xpathCtxt is not NULL, "XPath context not initialised"
+ path = _utf8(_path)
+ doc = self._element._doc
+
+ self._lock()
+ try:
+ self._context.register_context(doc)
+ c_doc = _fakeRootDoc(doc._c_doc, self._element._c_node)
+ try:
+ self._context.registerVariables(_variables)
+ c_path = _xcstr(path)
+ with nogil:
+ self._xpathCtxt.doc = c_doc
+ self._xpathCtxt.node = tree.xmlDocGetRootElement(c_doc)
+ xpathObj = xpath.xmlXPathEvalExpression(
+ c_path, self._xpathCtxt)
+ result = self._handle_result(xpathObj, doc)
+ finally:
+ _destroyFakeDoc(doc._c_doc, c_doc)
+ self._context.unregister_context()
+ finally:
+ self._unlock()
+
+ return result
+
+
+def XPathEvaluator(etree_or_element, *, namespaces=None, extensions=None,
+ regexp=True, smart_strings=True):
+ """XPathEvaluator(etree_or_element, namespaces=None, extensions=None, regexp=True, smart_strings=True)
+
+ Creates an XPath evaluator for an ElementTree or an Element.
+
+ The resulting object can be called with an XPath expression as argument
+ and XPath variables provided as keyword arguments.
+
+ Additional namespace declarations can be passed with the
+ 'namespace' keyword argument. EXSLT regular expression support
+ can be disabled with the 'regexp' boolean keyword (defaults to
+ True). Smart strings will be returned for string results unless
+ you pass ``smart_strings=False``.
+ """
+ if isinstance(etree_or_element, _ElementTree):
+ return XPathDocumentEvaluator(
+ etree_or_element, namespaces=namespaces,
+ extensions=extensions, regexp=regexp, smart_strings=smart_strings)
+ else:
+ return XPathElementEvaluator(
+ etree_or_element, namespaces=namespaces,
+ extensions=extensions, regexp=regexp, smart_strings=smart_strings)
+
+
+cdef class XPath(_XPathEvaluatorBase):
+ """XPath(self, path, namespaces=None, extensions=None, regexp=True, smart_strings=True)
+ A compiled XPath expression that can be called on Elements and ElementTrees.
+
+ Besides the XPath expression, you can pass prefix-namespace
+ mappings and extension functions to the constructor through the
+ keyword arguments ``namespaces`` and ``extensions``. EXSLT
+ regular expression support can be disabled with the 'regexp'
+ boolean keyword (defaults to True). Smart strings will be
+ returned for string results unless you pass
+ ``smart_strings=False``.
+ """
+ cdef xpath.xmlXPathCompExpr* _xpath
+ cdef bytes _path
+ def __cinit__(self):
+ self._xpath = NULL
+
+ def __init__(self, path, *, namespaces=None, extensions=None,
+ regexp=True, smart_strings=True):
+ cdef xpath.xmlXPathContext* xpathCtxt
+ _XPathEvaluatorBase.__init__(self, namespaces, extensions,
+ regexp, smart_strings)
+ self._path = _utf8(path)
+ xpathCtxt = xpath.xmlXPathNewContext(NULL)
+ if xpathCtxt is NULL:
+ raise MemoryError()
+ self.set_context(xpathCtxt)
+ self._xpath = xpath.xmlXPathCtxtCompile(xpathCtxt, _xcstr(self._path))
+ if self._xpath is NULL:
+ raise self._build_parse_error()
+
+ def __call__(self, _etree_or_element, **_variables):
+ "__call__(self, _etree_or_element, **_variables)"
+ cdef xpath.xmlXPathObject* xpathObj
+ cdef _Document document
+ cdef _Element element
+
+ assert self._xpathCtxt is not NULL, "XPath context not initialised"
+ document = _documentOrRaise(_etree_or_element)
+ element = _rootNodeOrRaise(_etree_or_element)
+
+ self._lock()
+ self._xpathCtxt.doc = document._c_doc
+ self._xpathCtxt.node = element._c_node
+
+ try:
+ self._context.register_context(document)
+ self._context.registerVariables(_variables)
+ with nogil:
+ xpathObj = xpath.xmlXPathCompiledEval(
+ self._xpath, self._xpathCtxt)
+ result = self._handle_result(xpathObj, document)
+ finally:
+ self._context.unregister_context()
+ self._unlock()
+ return result
+
+ @property
+ def path(self):
+ """The literal XPath expression.
+ """
+ return self._path.decode('UTF-8')
+
+ def __dealloc__(self):
+ if self._xpath is not NULL:
+ xpath.xmlXPathFreeCompExpr(self._xpath)
+
+ def __repr__(self):
+ return self.path
+
+
+cdef object _replace_strings = re.compile(b'("[^"]*")|(\'[^\']*\')').sub
+cdef object _find_namespaces = re.compile(b'({[^}]+})').findall
+
+cdef class ETXPath(XPath):
+ """ETXPath(self, path, extensions=None, regexp=True, smart_strings=True)
+ Special XPath class that supports the ElementTree {uri} notation for namespaces.
+
+ Note that this class does not accept the ``namespace`` keyword
+ argument. All namespaces must be passed as part of the path
+ string. Smart strings will be returned for string results unless
+ you pass ``smart_strings=False``.
+ """
+ def __init__(self, path, *, extensions=None, regexp=True,
+ smart_strings=True):
+ path, namespaces = self._nsextract_path(path)
+ XPath.__init__(self, path, namespaces=namespaces,
+ extensions=extensions, regexp=regexp,
+ smart_strings=smart_strings)
+
+ cdef _nsextract_path(self, path):
+ # replace {namespaces} by new prefixes
+ cdef dict namespaces = {}
+ cdef list namespace_defs = []
+ cdef int i
+ path_utf = _utf8(path)
+ stripped_path = _replace_strings(b'', path_utf) # remove string literals
+ i = 1
+ for namespace_def in _find_namespaces(stripped_path):
+ if namespace_def not in namespace_defs:
+ prefix = python.PyBytes_FromFormat("__xpp%02d", i)
+ i += 1
+ namespace_defs.append(namespace_def)
+ namespace = namespace_def[1:-1] # remove '{}'
+ namespace = (namespace).decode('utf8')
+ namespaces[prefix.decode('utf8')] = namespace
+ prefix_str = prefix + b':'
+ # FIXME: this also replaces {namespaces} within strings!
+ path_utf = path_utf.replace(namespace_def, prefix_str)
+ path = path_utf.decode('utf8')
+ return path, namespaces
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/xsltext.pxi b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/xsltext.pxi
new file mode 100644
index 0000000000000000000000000000000000000000..21894b9ef5859a455fd2f9f4443e805818b94517
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/xsltext.pxi
@@ -0,0 +1,242 @@
+# XSLT extension elements
+
+cdef class XSLTExtension:
+ """Base class of an XSLT extension element.
+ """
+ def execute(self, context, self_node, input_node, output_parent):
+ """execute(self, context, self_node, input_node, output_parent)
+ Execute this extension element.
+
+ Subclasses must override this method. They may append
+ elements to the `output_parent` element here, or set its text
+ content. To this end, the `input_node` provides read-only
+ access to the current node in the input document, and the
+ `self_node` points to the extension element in the stylesheet.
+
+ Note that the `output_parent` parameter may be `None` if there
+ is no parent element in the current context (e.g. no content
+ was added to the output tree yet).
+ """
+ pass
+
+ def apply_templates(self, _XSLTContext context not None, node, output_parent=None,
+ *, elements_only=False, remove_blank_text=False):
+ """apply_templates(self, context, node, output_parent=None, elements_only=False, remove_blank_text=False)
+
+ Call this method to retrieve the result of applying templates
+ to an element.
+
+ The return value is a list of elements or text strings that
+ were generated by the XSLT processor. If you pass
+ ``elements_only=True``, strings will be discarded from the result
+ list. The option ``remove_blank_text=True`` will only discard
+ strings that consist entirely of whitespace (e.g. formatting).
+ These options do not apply to Elements, only to bare string results.
+
+ If you pass an Element as `output_parent` parameter, the result
+ will instead be appended to the element (including attributes
+ etc.) and the return value will be `None`. This is a safe way
+ to generate content into the output document directly, without
+ having to take care of special values like text or attributes.
+ Note that the string discarding options will be ignored in this
+ case.
+ """
+ cdef xmlNode* c_parent
+ cdef xmlNode* c_node
+ cdef xmlNode* c_context_node
+ assert context._xsltCtxt is not NULL, "XSLT context not initialised"
+ c_context_node = _roNodeOf(node)
+ #assert c_context_node.doc is context._xsltContext.node.doc, \
+ # "switching input documents during transformation is not currently supported"
+
+ if output_parent is not None:
+ c_parent = _nonRoNodeOf(output_parent)
+ else:
+ c_parent = tree.xmlNewDocNode(
+ context._xsltCtxt.output, NULL, "fake-parent", NULL)
+
+ c_node = context._xsltCtxt.insert
+ context._xsltCtxt.insert = c_parent
+ xslt.xsltProcessOneNode(
+ context._xsltCtxt, c_context_node, NULL)
+ context._xsltCtxt.insert = c_node
+
+ if output_parent is not None:
+ return None
+
+ try:
+ return self._collectXSLTResultContent(
+ context, c_parent, elements_only, remove_blank_text)
+ finally:
+ # free all intermediate nodes that will not be freed by proxies
+ tree.xmlFreeNode(c_parent)
+
+ def process_children(self, _XSLTContext context not None, output_parent=None,
+ *, elements_only=False, remove_blank_text=False):
+ """process_children(self, context, output_parent=None, elements_only=False, remove_blank_text=False)
+
+ Call this method to process the XSLT content of the extension
+ element itself.
+
+ The return value is a list of elements or text strings that
+ were generated by the XSLT processor. If you pass
+ ``elements_only=True``, strings will be discarded from the result
+ list. The option ``remove_blank_text=True`` will only discard
+ strings that consist entirely of whitespace (e.g. formatting).
+ These options do not apply to Elements, only to bare string results.
+
+ If you pass an Element as `output_parent` parameter, the result
+ will instead be appended to the element (including attributes
+ etc.) and the return value will be `None`. This is a safe way
+ to generate content into the output document directly, without
+ having to take care of special values like text or attributes.
+ Note that the string discarding options will be ignored in this
+ case.
+ """
+ cdef xmlNode* c_parent
+ cdef xslt.xsltTransformContext* c_ctxt = context._xsltCtxt
+ cdef xmlNode* c_old_output_parent = c_ctxt.insert
+ assert context._xsltCtxt is not NULL, "XSLT context not initialised"
+
+ # output_parent node is used for adding results instead of
+ # elements list used in apply_templates, that's easier and allows to
+ # use attributes added to extension element with .
+
+ if output_parent is not None:
+ c_parent = _nonRoNodeOf(output_parent)
+ else:
+ c_parent = tree.xmlNewDocNode(
+ context._xsltCtxt.output, NULL, "fake-parent", NULL)
+
+ c_ctxt.insert = c_parent
+ xslt.xsltApplyOneTemplate(c_ctxt,
+ c_ctxt.node, c_ctxt.inst.children, NULL, NULL)
+ c_ctxt.insert = c_old_output_parent
+
+ if output_parent is not None:
+ return None
+
+ try:
+ return self._collectXSLTResultContent(
+ context, c_parent, elements_only, remove_blank_text)
+ finally:
+ # free all intermediate nodes that will not be freed by proxies
+ tree.xmlFreeNode(c_parent)
+
+ cdef _collectXSLTResultContent(self, _XSLTContext context, xmlNode* c_parent,
+ bint elements_only, bint remove_blank_text):
+ cdef xmlNode* c_node
+ cdef xmlNode* c_next
+ cdef _ReadOnlyProxy proxy
+ cdef list results = [] # or maybe _collectAttributes(c_parent, 2) ?
+ c_node = c_parent.children
+ while c_node is not NULL:
+ c_next = c_node.next
+ if c_node.type == tree.XML_TEXT_NODE:
+ if not elements_only:
+ s = funicode(c_node.content)
+ if not remove_blank_text or s.strip():
+ results.append(s)
+ s = None
+ elif c_node.type == tree.XML_ELEMENT_NODE:
+ proxy = _newReadOnlyProxy(
+ context._extension_element_proxy, c_node)
+ results.append(proxy)
+ # unlink node and make sure it will be freed later on
+ tree.xmlUnlinkNode(c_node)
+ proxy.free_after_use()
+ else:
+ raise TypeError, \
+ f"unsupported XSLT result type: {c_node.type}"
+ c_node = c_next
+ return results
+
+
+cdef _registerXSLTExtensions(xslt.xsltTransformContext* c_ctxt,
+ extension_dict):
+ for ns_utf, name_utf in extension_dict:
+ xslt.xsltRegisterExtElement(
+ c_ctxt, _xcstr(name_utf), _xcstr(ns_utf),
+ _callExtensionElement)
+
+cdef void _callExtensionElement(xslt.xsltTransformContext* c_ctxt,
+ xmlNode* c_context_node,
+ xmlNode* c_inst_node,
+ void* dummy) noexcept with gil:
+ cdef _XSLTContext context
+ cdef XSLTExtension extension
+ cdef python.PyObject* dict_result
+ cdef xmlNode* c_node
+ cdef _ReadOnlyProxy context_node = None, self_node = None
+ cdef object output_parent # not restricted to ro-nodes
+ c_uri = _getNs(c_inst_node)
+ if c_uri is NULL:
+ # not allowed, and should never happen
+ return
+ if c_ctxt.xpathCtxt.userData is NULL:
+ # just for safety, should never happen
+ return
+ context = <_XSLTContext>c_ctxt.xpathCtxt.userData
+ try:
+ try:
+ dict_result = python.PyDict_GetItem(
+ context._extension_elements, (c_uri, c_inst_node.name))
+ if dict_result is NULL:
+ raise KeyError, f"extension element {funicode(c_inst_node.name)} not found"
+ extension = dict_result
+
+ try:
+ # build the context proxy nodes
+ self_node = _newReadOnlyProxy(None, c_inst_node)
+ if _isElement(c_ctxt.insert):
+ output_parent = _newAppendOnlyProxy(self_node, c_ctxt.insert)
+ else:
+ # may be the document node or other stuff
+ output_parent = _newOpaqueAppendOnlyNodeWrapper(c_ctxt.insert)
+ if c_context_node.type in (tree.XML_DOCUMENT_NODE,
+ tree.XML_HTML_DOCUMENT_NODE):
+ c_node = tree.xmlDocGetRootElement(c_context_node)
+ if c_node is not NULL:
+ context_node = _newReadOnlyProxy(self_node, c_node)
+ else:
+ context_node = None
+ elif c_context_node.type in (tree.XML_ATTRIBUTE_NODE,
+ tree.XML_TEXT_NODE,
+ tree.XML_CDATA_SECTION_NODE):
+ # this isn't easy to support using read-only
+ # nodes, as the smart-string factory must
+ # instantiate the parent proxy somehow...
+ raise TypeError(f"Unsupported element type: {c_context_node.type}")
+ else:
+ context_node = _newReadOnlyProxy(self_node, c_context_node)
+
+ # run the XSLT extension
+ context._extension_element_proxy = self_node
+ extension.execute(context, self_node, context_node, output_parent)
+ finally:
+ context._extension_element_proxy = None
+ if self_node is not None:
+ _freeReadOnlyProxies(self_node)
+ except Exception as e:
+ try:
+ e = unicode(e).encode("UTF-8")
+ except:
+ e = repr(e).encode("UTF-8")
+ message = python.PyBytes_FromFormat(
+ "Error executing extension element '%s': %s",
+ c_inst_node.name, _cstr(e))
+ xslt.xsltTransformError(c_ctxt, NULL, c_inst_node, "%s", message)
+ context._exc._store_raised()
+ except:
+ # just in case
+ message = python.PyBytes_FromFormat(
+ "Error executing extension element '%s'", c_inst_node.name)
+ xslt.xsltTransformError(c_ctxt, NULL, c_inst_node, "%s", message)
+ context._exc._store_raised()
+ except:
+ # no Python functions here - everything can fail...
+ xslt.xsltTransformError(c_ctxt, NULL, c_inst_node,
+ "Error during XSLT extension element evaluation")
+ context._exc._store_raised()
+ finally:
+ return # swallow any further exceptions
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..181ca3b0c6bee682ebce2ba52772f2ce1f9b4d30
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/__init__.py
@@ -0,0 +1,975 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# When adding a new object to this init, remember to add it twice: once inside the `_import_structure` dictionary and
+# once inside the `if TYPE_CHECKING` branch. The `TYPE_CHECKING` should have import statements as usual, but they are
+# only there for type checking. The `_import_structure` is a dictionary submodule to list of object names, and is used
+# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
+# in the namespace without actually importing anything (and especially none of the backends).
+
+__version__ = "4.57.6"
+
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+# Check the dependencies satisfy the minimal versions required.
+from . import dependency_versions_check
+from .utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_essentia_available,
+ is_g2p_en_available,
+ is_librosa_available,
+ is_mistral_common_available,
+ is_mlx_available,
+ is_pretty_midi_available,
+)
+
+# Note: the following symbols are deliberately exported with `as`
+# so that mypy, pylint or other static linters can recognize them,
+# given that they are not exported using `__all__` in this file.
+from .utils import is_bitsandbytes_available as is_bitsandbytes_available
+from .utils import is_flax_available as is_flax_available
+from .utils import is_keras_nlp_available as is_keras_nlp_available
+from .utils import is_scipy_available as is_scipy_available
+from .utils import is_sentencepiece_available as is_sentencepiece_available
+from .utils import is_speech_available as is_speech_available
+from .utils import is_tensorflow_text_available as is_tensorflow_text_available
+from .utils import is_tf_available as is_tf_available
+from .utils import is_timm_available as is_timm_available
+from .utils import is_tokenizers_available as is_tokenizers_available
+from .utils import is_torch_available as is_torch_available
+from .utils import is_torchaudio_available as is_torchaudio_available
+from .utils import is_torchvision_available as is_torchvision_available
+from .utils import is_vision_available as is_vision_available
+from .utils import logging as logging
+from .utils.import_utils import define_import_structure
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+# Base objects, independent of any specific backend
+_import_structure = {
+ "audio_utils": [],
+ "commands": [],
+ "configuration_utils": ["PretrainedConfig"],
+ "convert_graph_to_onnx": [],
+ "convert_slow_tokenizers_checkpoints_to_fast": [],
+ "convert_tf_hub_seq_to_seq_bert_to_pytorch": [],
+ "data": [
+ "DataProcessor",
+ "InputExample",
+ "InputFeatures",
+ "SingleSentenceClassificationProcessor",
+ "SquadExample",
+ "SquadFeatures",
+ "SquadV1Processor",
+ "SquadV2Processor",
+ "glue_compute_metrics",
+ "glue_convert_examples_to_features",
+ "glue_output_modes",
+ "glue_processors",
+ "glue_tasks_num_labels",
+ "squad_convert_examples_to_features",
+ "xnli_compute_metrics",
+ "xnli_output_modes",
+ "xnli_processors",
+ "xnli_tasks_num_labels",
+ ],
+ "data.data_collator": [
+ "DataCollator",
+ "DataCollatorForLanguageModeling",
+ "DataCollatorForMultipleChoice",
+ "DataCollatorForPermutationLanguageModeling",
+ "DataCollatorForSeq2Seq",
+ "DataCollatorForSOP",
+ "DataCollatorForTokenClassification",
+ "DataCollatorForWholeWordMask",
+ "DataCollatorWithFlattening",
+ "DataCollatorWithPadding",
+ "DefaultDataCollator",
+ "default_data_collator",
+ ],
+ "data.metrics": [],
+ "data.processors": [],
+ "debug_utils": [],
+ "dependency_versions_check": [],
+ "dependency_versions_table": [],
+ "dynamic_module_utils": [],
+ "feature_extraction_sequence_utils": ["SequenceFeatureExtractor"],
+ "feature_extraction_utils": ["BatchFeature", "FeatureExtractionMixin"],
+ "file_utils": [],
+ "generation": [
+ "AsyncTextIteratorStreamer",
+ "CompileConfig",
+ "GenerationConfig",
+ "TextIteratorStreamer",
+ "TextStreamer",
+ "WatermarkingConfig",
+ ],
+ "hf_argparser": ["HfArgumentParser"],
+ "hyperparameter_search": [],
+ "image_transforms": [],
+ "integrations": [
+ "is_clearml_available",
+ "is_comet_available",
+ "is_dvclive_available",
+ "is_neptune_available",
+ "is_optuna_available",
+ "is_ray_available",
+ "is_ray_tune_available",
+ "is_sigopt_available",
+ "is_swanlab_available",
+ "is_tensorboard_available",
+ "is_trackio_available",
+ "is_wandb_available",
+ ],
+ "loss": [],
+ "modelcard": ["ModelCard"],
+ # Losses
+ "modeling_tf_pytorch_utils": [
+ "convert_tf_weight_name_to_pt_weight_name",
+ "load_pytorch_checkpoint_in_tf2_model",
+ "load_pytorch_model_in_tf2_model",
+ "load_pytorch_weights_in_tf2_model",
+ "load_tf2_checkpoint_in_pytorch_model",
+ "load_tf2_model_in_pytorch_model",
+ "load_tf2_weights_in_pytorch_model",
+ ],
+ # Models
+ "onnx": [],
+ "pipelines": [
+ "AudioClassificationPipeline",
+ "AutomaticSpeechRecognitionPipeline",
+ "CsvPipelineDataFormat",
+ "DepthEstimationPipeline",
+ "DocumentQuestionAnsweringPipeline",
+ "FeatureExtractionPipeline",
+ "FillMaskPipeline",
+ "ImageClassificationPipeline",
+ "ImageFeatureExtractionPipeline",
+ "ImageSegmentationPipeline",
+ "ImageTextToTextPipeline",
+ "ImageToImagePipeline",
+ "ImageToTextPipeline",
+ "JsonPipelineDataFormat",
+ "KeypointMatchingPipeline",
+ "MaskGenerationPipeline",
+ "NerPipeline",
+ "ObjectDetectionPipeline",
+ "PipedPipelineDataFormat",
+ "Pipeline",
+ "PipelineDataFormat",
+ "QuestionAnsweringPipeline",
+ "SummarizationPipeline",
+ "TableQuestionAnsweringPipeline",
+ "Text2TextGenerationPipeline",
+ "TextClassificationPipeline",
+ "TextGenerationPipeline",
+ "TextToAudioPipeline",
+ "TokenClassificationPipeline",
+ "TranslationPipeline",
+ "VideoClassificationPipeline",
+ "VisualQuestionAnsweringPipeline",
+ "ZeroShotAudioClassificationPipeline",
+ "ZeroShotClassificationPipeline",
+ "ZeroShotImageClassificationPipeline",
+ "ZeroShotObjectDetectionPipeline",
+ "pipeline",
+ ],
+ "processing_utils": ["ProcessorMixin"],
+ "quantizers": [],
+ "testing_utils": [],
+ "tokenization_utils": ["PreTrainedTokenizer"],
+ "tokenization_utils_base": [
+ "AddedToken",
+ "BatchEncoding",
+ "CharSpan",
+ "PreTrainedTokenizerBase",
+ "SpecialTokensMixin",
+ "TokenSpan",
+ ],
+ "trainer_callback": [
+ "DefaultFlowCallback",
+ "EarlyStoppingCallback",
+ "PrinterCallback",
+ "ProgressCallback",
+ "TrainerCallback",
+ "TrainerControl",
+ "TrainerState",
+ ],
+ "trainer_utils": [
+ "EvalPrediction",
+ "IntervalStrategy",
+ "SchedulerType",
+ "enable_full_determinism",
+ "set_seed",
+ ],
+ "training_args": ["TrainingArguments"],
+ "training_args_seq2seq": ["Seq2SeqTrainingArguments"],
+ "training_args_tf": ["TFTrainingArguments"],
+ "utils": [
+ "CONFIG_NAME",
+ "MODEL_CARD_NAME",
+ "PYTORCH_PRETRAINED_BERT_CACHE",
+ "PYTORCH_TRANSFORMERS_CACHE",
+ "SPIECE_UNDERLINE",
+ "TF2_WEIGHTS_NAME",
+ "TF_WEIGHTS_NAME",
+ "TRANSFORMERS_CACHE",
+ "WEIGHTS_NAME",
+ "TensorType",
+ "add_end_docstrings",
+ "add_start_docstrings",
+ "is_apex_available",
+ "is_av_available",
+ "is_bitsandbytes_available",
+ "is_datasets_available",
+ "is_faiss_available",
+ "is_flax_available",
+ "is_keras_nlp_available",
+ "is_matplotlib_available",
+ "is_mlx_available",
+ "is_phonemizer_available",
+ "is_psutil_available",
+ "is_py3nvml_available",
+ "is_pyctcdecode_available",
+ "is_sacremoses_available",
+ "is_safetensors_available",
+ "is_scipy_available",
+ "is_sentencepiece_available",
+ "is_sklearn_available",
+ "is_speech_available",
+ "is_tensorflow_text_available",
+ "is_tf_available",
+ "is_timm_available",
+ "is_tokenizers_available",
+ "is_torch_available",
+ "is_torch_hpu_available",
+ "is_torch_mlu_available",
+ "is_torch_musa_available",
+ "is_torch_neuroncore_available",
+ "is_torch_npu_available",
+ "is_torchvision_available",
+ "is_torch_xla_available",
+ "is_torch_xpu_available",
+ "is_vision_available",
+ "logging",
+ ],
+ "utils.quantization_config": [
+ "AqlmConfig",
+ "AutoRoundConfig",
+ "AwqConfig",
+ "BitNetQuantConfig",
+ "BitsAndBytesConfig",
+ "CompressedTensorsConfig",
+ "EetqConfig",
+ "FbgemmFp8Config",
+ "FineGrainedFP8Config",
+ "GPTQConfig",
+ "HiggsConfig",
+ "HqqConfig",
+ "Mxfp4Config",
+ "QuantoConfig",
+ "QuarkConfig",
+ "FPQuantConfig",
+ "SpQRConfig",
+ "TorchAoConfig",
+ "VptqConfig",
+ ],
+ "video_utils": [],
+}
+
+# tokenizers-backed objects
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils import dummy_tokenizers_objects
+
+ _import_structure["utils.dummy_tokenizers_objects"] = [
+ name for name in dir(dummy_tokenizers_objects) if not name.startswith("_")
+ ]
+else:
+ # Fast tokenizers structure
+ _import_structure["tokenization_utils_fast"] = ["PreTrainedTokenizerFast"]
+
+
+try:
+ if not (is_sentencepiece_available() and is_tokenizers_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils import dummy_sentencepiece_and_tokenizers_objects
+
+ _import_structure["utils.dummy_sentencepiece_and_tokenizers_objects"] = [
+ name for name in dir(dummy_sentencepiece_and_tokenizers_objects) if not name.startswith("_")
+ ]
+else:
+ _import_structure["convert_slow_tokenizer"] = [
+ "SLOW_TO_FAST_CONVERTERS",
+ "convert_slow_tokenizer",
+ ]
+
+try:
+ if not (is_mistral_common_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils import dummy_mistral_common_objects
+
+ _import_structure["utils.dummy_mistral_common_objects"] = [
+ name for name in dir(dummy_mistral_common_objects) if not name.startswith("_")
+ ]
+else:
+ _import_structure["tokenization_mistral_common"] = ["MistralCommonTokenizer"]
+
+# Vision-specific objects
+try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils import dummy_vision_objects
+
+ _import_structure["utils.dummy_vision_objects"] = [
+ name for name in dir(dummy_vision_objects) if not name.startswith("_")
+ ]
+else:
+ _import_structure["image_processing_base"] = ["ImageProcessingMixin"]
+ _import_structure["image_processing_utils"] = ["BaseImageProcessor"]
+ _import_structure["image_utils"] = ["ImageFeatureExtractionMixin"]
+
+try:
+ if not is_torchvision_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils import dummy_torchvision_objects
+
+ _import_structure["utils.dummy_torchvision_objects"] = [
+ name for name in dir(dummy_torchvision_objects) if not name.startswith("_")
+ ]
+else:
+ _import_structure["image_processing_utils_fast"] = ["BaseImageProcessorFast"]
+ _import_structure["video_processing_utils"] = ["BaseVideoProcessor"]
+
+# PyTorch-backed objects
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils import dummy_pt_objects
+
+ _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
+else:
+ _import_structure["model_debugging_utils"] = [
+ "model_addition_debugger_context",
+ ]
+ _import_structure["activations"] = []
+ _import_structure["cache_utils"] = [
+ "CacheLayerMixin",
+ "DynamicLayer",
+ "StaticLayer",
+ "StaticSlidingWindowLayer",
+ "SlidingWindowLayer",
+ "ChunkedSlidingLayer",
+ "QuantoQuantizedLayer",
+ "HQQQuantizedLayer",
+ "Cache",
+ "DynamicCache",
+ "EncoderDecoderCache",
+ "HQQQuantizedCache",
+ "HybridCache",
+ "HybridChunkedCache",
+ "OffloadedCache",
+ "OffloadedStaticCache",
+ "QuantizedCache",
+ "QuantoQuantizedCache",
+ "SinkCache",
+ "SlidingWindowCache",
+ "StaticCache",
+ ]
+ _import_structure["data.datasets"] = [
+ "GlueDataset",
+ "GlueDataTrainingArguments",
+ "LineByLineTextDataset",
+ "LineByLineWithRefDataset",
+ "LineByLineWithSOPTextDataset",
+ "SquadDataset",
+ "SquadDataTrainingArguments",
+ "TextDataset",
+ "TextDatasetForNextSentencePrediction",
+ ]
+ _import_structure["generation"].extend(
+ [
+ "AlternatingCodebooksLogitsProcessor",
+ "BayesianDetectorConfig",
+ "BayesianDetectorModel",
+ "BeamScorer",
+ "ClassifierFreeGuidanceLogitsProcessor",
+ "ConstrainedBeamSearchScorer",
+ "Constraint",
+ "ConstraintListState",
+ "DisjunctiveConstraint",
+ "EncoderNoRepeatNGramLogitsProcessor",
+ "EncoderRepetitionPenaltyLogitsProcessor",
+ "EosTokenCriteria",
+ "EpsilonLogitsWarper",
+ "EtaLogitsWarper",
+ "ExponentialDecayLengthPenalty",
+ "ForcedBOSTokenLogitsProcessor",
+ "ForcedEOSTokenLogitsProcessor",
+ "GenerationMixin",
+ "InfNanRemoveLogitsProcessor",
+ "LogitNormalization",
+ "LogitsProcessor",
+ "LogitsProcessorList",
+ "MaxLengthCriteria",
+ "MaxTimeCriteria",
+ "MinLengthLogitsProcessor",
+ "MinNewTokensLengthLogitsProcessor",
+ "MinPLogitsWarper",
+ "NoBadWordsLogitsProcessor",
+ "NoRepeatNGramLogitsProcessor",
+ "PhrasalConstraint",
+ "PrefixConstrainedLogitsProcessor",
+ "RepetitionPenaltyLogitsProcessor",
+ "SequenceBiasLogitsProcessor",
+ "StoppingCriteria",
+ "StoppingCriteriaList",
+ "StopStringCriteria",
+ "SuppressTokensAtBeginLogitsProcessor",
+ "SuppressTokensLogitsProcessor",
+ "SynthIDTextWatermarkDetector",
+ "SynthIDTextWatermarkingConfig",
+ "SynthIDTextWatermarkLogitsProcessor",
+ "TemperatureLogitsWarper",
+ "TopKLogitsWarper",
+ "TopPLogitsWarper",
+ "TypicalLogitsWarper",
+ "UnbatchedClassifierFreeGuidanceLogitsProcessor",
+ "WatermarkDetector",
+ "WatermarkLogitsProcessor",
+ "WhisperTimeStampLogitsProcessor",
+ ]
+ )
+
+ # PyTorch domain libraries integration
+ _import_structure["integrations.executorch"] = [
+ "TorchExportableModuleWithStaticCache",
+ "convert_and_export_with_cache",
+ ]
+
+ _import_structure["modeling_flash_attention_utils"] = []
+ _import_structure["modeling_layers"] = ["GradientCheckpointingLayer"]
+ _import_structure["modeling_outputs"] = []
+ _import_structure["modeling_rope_utils"] = ["ROPE_INIT_FUNCTIONS", "dynamic_rope_update"]
+ _import_structure["modeling_utils"] = ["PreTrainedModel", "AttentionInterface"]
+ _import_structure["masking_utils"] = ["AttentionMaskInterface"]
+ _import_structure["optimization"] = [
+ "Adafactor",
+ "get_constant_schedule",
+ "get_constant_schedule_with_warmup",
+ "get_cosine_schedule_with_warmup",
+ "get_cosine_with_hard_restarts_schedule_with_warmup",
+ "get_cosine_with_min_lr_schedule_with_warmup",
+ "get_cosine_with_min_lr_schedule_with_warmup_lr_rate",
+ "get_inverse_sqrt_schedule",
+ "get_linear_schedule_with_warmup",
+ "get_polynomial_decay_schedule_with_warmup",
+ "get_scheduler",
+ "get_wsd_schedule",
+ "get_reduce_on_plateau_schedule",
+ ]
+ _import_structure["pytorch_utils"] = [
+ "Conv1D",
+ "apply_chunking_to_forward",
+ "prune_layer",
+ "infer_device",
+ ]
+ _import_structure["sagemaker"] = []
+ _import_structure["time_series_utils"] = []
+ _import_structure["trainer"] = ["Trainer"]
+ _import_structure["trainer_pt_utils"] = ["torch_distributed_zero_first"]
+ _import_structure["trainer_seq2seq"] = ["Seq2SeqTrainer"]
+
+# TensorFlow-backed objects
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils import dummy_tf_objects
+
+ _import_structure["utils.dummy_tf_objects"] = [name for name in dir(dummy_tf_objects) if not name.startswith("_")]
+else:
+ _import_structure["activations_tf"] = []
+ _import_structure["generation"].extend(
+ [
+ "TFForcedBOSTokenLogitsProcessor",
+ "TFForcedEOSTokenLogitsProcessor",
+ "TFForceTokensLogitsProcessor",
+ "TFGenerationMixin",
+ "TFLogitsProcessor",
+ "TFLogitsProcessorList",
+ "TFLogitsWarper",
+ "TFMinLengthLogitsProcessor",
+ "TFNoBadWordsLogitsProcessor",
+ "TFNoRepeatNGramLogitsProcessor",
+ "TFRepetitionPenaltyLogitsProcessor",
+ "TFSuppressTokensAtBeginLogitsProcessor",
+ "TFSuppressTokensLogitsProcessor",
+ "TFTemperatureLogitsWarper",
+ "TFTopKLogitsWarper",
+ "TFTopPLogitsWarper",
+ ]
+ )
+ _import_structure["keras_callbacks"] = ["KerasMetricCallback", "PushToHubCallback"]
+ _import_structure["modeling_tf_outputs"] = []
+ _import_structure["modeling_tf_utils"] = [
+ "TFPreTrainedModel",
+ "TFSequenceSummary",
+ "TFSharedEmbeddings",
+ "shape_list",
+ ]
+ _import_structure["optimization_tf"] = [
+ "AdamWeightDecay",
+ "GradientAccumulator",
+ "WarmUp",
+ "create_optimizer",
+ ]
+ _import_structure["tf_utils"] = []
+
+
+# FLAX-backed objects
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils import dummy_flax_objects
+
+ _import_structure["utils.dummy_flax_objects"] = [
+ name for name in dir(dummy_flax_objects) if not name.startswith("_")
+ ]
+else:
+ _import_structure["generation"].extend(
+ [
+ "FlaxForcedBOSTokenLogitsProcessor",
+ "FlaxForcedEOSTokenLogitsProcessor",
+ "FlaxForceTokensLogitsProcessor",
+ "FlaxGenerationMixin",
+ "FlaxLogitsProcessor",
+ "FlaxLogitsProcessorList",
+ "FlaxLogitsWarper",
+ "FlaxMinLengthLogitsProcessor",
+ "FlaxTemperatureLogitsWarper",
+ "FlaxSuppressTokensAtBeginLogitsProcessor",
+ "FlaxSuppressTokensLogitsProcessor",
+ "FlaxTopKLogitsWarper",
+ "FlaxTopPLogitsWarper",
+ "FlaxWhisperTimeStampLogitsProcessor",
+ ]
+ )
+ _import_structure["modeling_flax_outputs"] = []
+ _import_structure["modeling_flax_utils"] = ["FlaxPreTrainedModel"]
+
+# Direct imports for type-checking
+if TYPE_CHECKING:
+ # All modeling imports
+ from .cache_utils import Cache as Cache
+ from .cache_utils import ChunkedSlidingLayer as ChunkedSlidingLayer
+ from .cache_utils import DynamicCache as DynamicCache
+ from .cache_utils import DynamicLayer as DynamicLayer
+ from .cache_utils import EncoderDecoderCache as EncoderDecoderCache
+ from .cache_utils import HQQQuantizedCache as HQQQuantizedCache
+ from .cache_utils import HQQQuantizedLayer as HQQQuantizedLayer
+ from .cache_utils import HybridCache as HybridCache
+ from .cache_utils import OffloadedCache as OffloadedCache
+ from .cache_utils import OffloadedStaticCache as OffloadedStaticCache
+ from .cache_utils import QuantizedCache as QuantizedCache
+ from .cache_utils import QuantoQuantizedCache as QuantoQuantizedCache
+ from .cache_utils import QuantoQuantizedLayer as QuantoQuantizedLayer
+ from .cache_utils import SinkCache as SinkCache
+ from .cache_utils import SlidingWindowCache as SlidingWindowCache
+ from .cache_utils import SlidingWindowLayer as SlidingWindowLayer
+ from .cache_utils import StaticCache as StaticCache
+ from .cache_utils import StaticLayer as StaticLayer
+ from .cache_utils import StaticSlidingWindowLayer as StaticSlidingWindowLayer
+ from .configuration_utils import PretrainedConfig as PretrainedConfig
+ from .convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS as SLOW_TO_FAST_CONVERTERS
+ from .convert_slow_tokenizer import convert_slow_tokenizer as convert_slow_tokenizer
+
+ # Data
+ from .data import DataProcessor as DataProcessor
+ from .data import InputExample as InputExample
+ from .data import InputFeatures as InputFeatures
+ from .data import SingleSentenceClassificationProcessor as SingleSentenceClassificationProcessor
+ from .data import SquadExample as SquadExample
+ from .data import SquadFeatures as SquadFeatures
+ from .data import SquadV1Processor as SquadV1Processor
+ from .data import SquadV2Processor as SquadV2Processor
+ from .data import glue_compute_metrics as glue_compute_metrics
+ from .data import glue_convert_examples_to_features as glue_convert_examples_to_features
+ from .data import glue_output_modes as glue_output_modes
+ from .data import glue_processors as glue_processors
+ from .data import glue_tasks_num_labels as glue_tasks_num_labels
+ from .data import squad_convert_examples_to_features as squad_convert_examples_to_features
+ from .data import xnli_compute_metrics as xnli_compute_metrics
+ from .data import xnli_output_modes as xnli_output_modes
+ from .data import xnli_processors as xnli_processors
+ from .data import xnli_tasks_num_labels as xnli_tasks_num_labels
+ from .data.data_collator import DataCollator as DataCollator
+ from .data.data_collator import DataCollatorForLanguageModeling as DataCollatorForLanguageModeling
+ from .data.data_collator import DataCollatorForMultipleChoice as DataCollatorForMultipleChoice
+ from .data.data_collator import (
+ DataCollatorForPermutationLanguageModeling as DataCollatorForPermutationLanguageModeling,
+ )
+ from .data.data_collator import DataCollatorForSeq2Seq as DataCollatorForSeq2Seq
+ from .data.data_collator import DataCollatorForSOP as DataCollatorForSOP
+ from .data.data_collator import DataCollatorForTokenClassification as DataCollatorForTokenClassification
+ from .data.data_collator import DataCollatorForWholeWordMask as DataCollatorForWholeWordMask
+ from .data.data_collator import DataCollatorWithFlattening as DataCollatorWithFlattening
+ from .data.data_collator import DataCollatorWithPadding as DataCollatorWithPadding
+ from .data.data_collator import DefaultDataCollator as DefaultDataCollator
+ from .data.data_collator import default_data_collator as default_data_collator
+ from .data.datasets import GlueDataset as GlueDataset
+ from .data.datasets import GlueDataTrainingArguments as GlueDataTrainingArguments
+ from .data.datasets import LineByLineTextDataset as LineByLineTextDataset
+ from .data.datasets import LineByLineWithRefDataset as LineByLineWithRefDataset
+ from .data.datasets import LineByLineWithSOPTextDataset as LineByLineWithSOPTextDataset
+ from .data.datasets import SquadDataset as SquadDataset
+ from .data.datasets import SquadDataTrainingArguments as SquadDataTrainingArguments
+ from .data.datasets import TextDataset as TextDataset
+ from .data.datasets import TextDatasetForNextSentencePrediction as TextDatasetForNextSentencePrediction
+ from .feature_extraction_sequence_utils import SequenceFeatureExtractor as SequenceFeatureExtractor
+
+ # Feature Extractor
+ from .feature_extraction_utils import BatchFeature as BatchFeature
+ from .feature_extraction_utils import FeatureExtractionMixin as FeatureExtractionMixin
+
+ # Generation
+ from .generation import AlternatingCodebooksLogitsProcessor as AlternatingCodebooksLogitsProcessor
+ from .generation import AsyncTextIteratorStreamer as AsyncTextIteratorStreamer
+ from .generation import BayesianDetectorConfig as BayesianDetectorConfig
+ from .generation import BayesianDetectorModel as BayesianDetectorModel
+ from .generation import BeamScorer as BeamScorer
+ from .generation import ClassifierFreeGuidanceLogitsProcessor as ClassifierFreeGuidanceLogitsProcessor
+ from .generation import CompileConfig as CompileConfig
+ from .generation import ConstrainedBeamSearchScorer as ConstrainedBeamSearchScorer
+ from .generation import Constraint as Constraint
+ from .generation import ConstraintListState as ConstraintListState
+ from .generation import DisjunctiveConstraint as DisjunctiveConstraint
+ from .generation import EncoderNoRepeatNGramLogitsProcessor as EncoderNoRepeatNGramLogitsProcessor
+ from .generation import EncoderRepetitionPenaltyLogitsProcessor as EncoderRepetitionPenaltyLogitsProcessor
+ from .generation import EosTokenCriteria as EosTokenCriteria
+ from .generation import EpsilonLogitsWarper as EpsilonLogitsWarper
+ from .generation import EtaLogitsWarper as EtaLogitsWarper
+ from .generation import ExponentialDecayLengthPenalty as ExponentialDecayLengthPenalty
+ from .generation import FlaxForcedBOSTokenLogitsProcessor as FlaxForcedBOSTokenLogitsProcessor
+ from .generation import FlaxForcedEOSTokenLogitsProcessor as FlaxForcedEOSTokenLogitsProcessor
+ from .generation import FlaxForceTokensLogitsProcessor as FlaxForceTokensLogitsProcessor
+ from .generation import FlaxGenerationMixin as FlaxGenerationMixin
+ from .generation import FlaxLogitsProcessor as FlaxLogitsProcessor
+ from .generation import FlaxLogitsProcessorList as FlaxLogitsProcessorList
+ from .generation import FlaxLogitsWarper as FlaxLogitsWarper
+ from .generation import FlaxMinLengthLogitsProcessor as FlaxMinLengthLogitsProcessor
+ from .generation import FlaxSuppressTokensAtBeginLogitsProcessor as FlaxSuppressTokensAtBeginLogitsProcessor
+ from .generation import FlaxSuppressTokensLogitsProcessor as FlaxSuppressTokensLogitsProcessor
+ from .generation import FlaxTemperatureLogitsWarper as FlaxTemperatureLogitsWarper
+ from .generation import FlaxTopKLogitsWarper as FlaxTopKLogitsWarper
+ from .generation import FlaxTopPLogitsWarper as FlaxTopPLogitsWarper
+ from .generation import FlaxWhisperTimeStampLogitsProcessor as FlaxWhisperTimeStampLogitsProcessor
+ from .generation import ForcedBOSTokenLogitsProcessor as ForcedBOSTokenLogitsProcessor
+ from .generation import ForcedEOSTokenLogitsProcessor as ForcedEOSTokenLogitsProcessor
+ from .generation import GenerationConfig as GenerationConfig
+ from .generation import GenerationMixin as GenerationMixin
+ from .generation import InfNanRemoveLogitsProcessor as InfNanRemoveLogitsProcessor
+ from .generation import LogitNormalization as LogitNormalization
+ from .generation import LogitsProcessor as LogitsProcessor
+ from .generation import LogitsProcessorList as LogitsProcessorList
+ from .generation import MaxLengthCriteria as MaxLengthCriteria
+ from .generation import MaxTimeCriteria as MaxTimeCriteria
+ from .generation import MinLengthLogitsProcessor as MinLengthLogitsProcessor
+ from .generation import MinNewTokensLengthLogitsProcessor as MinNewTokensLengthLogitsProcessor
+ from .generation import MinPLogitsWarper as MinPLogitsWarper
+ from .generation import NoBadWordsLogitsProcessor as NoBadWordsLogitsProcessor
+ from .generation import NoRepeatNGramLogitsProcessor as NoRepeatNGramLogitsProcessor
+ from .generation import PhrasalConstraint as PhrasalConstraint
+ from .generation import PrefixConstrainedLogitsProcessor as PrefixConstrainedLogitsProcessor
+ from .generation import RepetitionPenaltyLogitsProcessor as RepetitionPenaltyLogitsProcessor
+ from .generation import SequenceBiasLogitsProcessor as SequenceBiasLogitsProcessor
+ from .generation import StoppingCriteria as StoppingCriteria
+ from .generation import StoppingCriteriaList as StoppingCriteriaList
+ from .generation import StopStringCriteria as StopStringCriteria
+ from .generation import SuppressTokensAtBeginLogitsProcessor as SuppressTokensAtBeginLogitsProcessor
+ from .generation import SuppressTokensLogitsProcessor as SuppressTokensLogitsProcessor
+ from .generation import SynthIDTextWatermarkDetector as SynthIDTextWatermarkDetector
+ from .generation import SynthIDTextWatermarkingConfig as SynthIDTextWatermarkingConfig
+ from .generation import SynthIDTextWatermarkLogitsProcessor as SynthIDTextWatermarkLogitsProcessor
+ from .generation import TemperatureLogitsWarper as TemperatureLogitsWarper
+ from .generation import TextIteratorStreamer as TextIteratorStreamer
+ from .generation import TextStreamer as TextStreamer
+ from .generation import TFForcedBOSTokenLogitsProcessor as TFForcedBOSTokenLogitsProcessor
+ from .generation import TFForcedEOSTokenLogitsProcessor as TFForcedEOSTokenLogitsProcessor
+ from .generation import TFForceTokensLogitsProcessor as TFForceTokensLogitsProcessor
+ from .generation import TFGenerationMixin as TFGenerationMixin
+ from .generation import TFLogitsProcessor as TFLogitsProcessor
+ from .generation import TFLogitsProcessorList as TFLogitsProcessorList
+ from .generation import TFLogitsWarper as TFLogitsWarper
+ from .generation import TFMinLengthLogitsProcessor as TFMinLengthLogitsProcessor
+ from .generation import TFNoBadWordsLogitsProcessor as TFNoBadWordsLogitsProcessor
+ from .generation import TFNoRepeatNGramLogitsProcessor as TFNoRepeatNGramLogitsProcessor
+ from .generation import TFRepetitionPenaltyLogitsProcessor as TFRepetitionPenaltyLogitsProcessor
+ from .generation import TFSuppressTokensAtBeginLogitsProcessor as TFSuppressTokensAtBeginLogitsProcessor
+ from .generation import TFSuppressTokensLogitsProcessor as TFSuppressTokensLogitsProcessor
+ from .generation import TFTemperatureLogitsWarper as TFTemperatureLogitsWarper
+ from .generation import TFTopKLogitsWarper as TFTopKLogitsWarper
+ from .generation import TFTopPLogitsWarper as TFTopPLogitsWarper
+ from .generation import TopKLogitsWarper as TopKLogitsWarper
+ from .generation import TopPLogitsWarper as TopPLogitsWarper
+ from .generation import TypicalLogitsWarper as TypicalLogitsWarper
+ from .generation import (
+ UnbatchedClassifierFreeGuidanceLogitsProcessor as UnbatchedClassifierFreeGuidanceLogitsProcessor,
+ )
+ from .generation import WatermarkDetector as WatermarkDetector
+ from .generation import WatermarkingConfig as WatermarkingConfig
+ from .generation import WatermarkLogitsProcessor as WatermarkLogitsProcessor
+ from .generation import WhisperTimeStampLogitsProcessor as WhisperTimeStampLogitsProcessor
+ from .hf_argparser import HfArgumentParser as HfArgumentParser
+ from .image_processing_base import ImageProcessingMixin as ImageProcessingMixin
+ from .image_processing_utils import BaseImageProcessor as BaseImageProcessor
+ from .image_processing_utils_fast import BaseImageProcessorFast as BaseImageProcessorFast
+ from .image_utils import ImageFeatureExtractionMixin as ImageFeatureExtractionMixin
+
+ # Integrations
+ from .integrations import is_clearml_available as is_clearml_available
+ from .integrations import is_comet_available as is_comet_available
+ from .integrations import is_dvclive_available as is_dvclive_available
+ from .integrations import is_neptune_available as is_neptune_available
+ from .integrations import is_optuna_available as is_optuna_available
+ from .integrations import is_ray_available as is_ray_available
+ from .integrations import is_ray_tune_available as is_ray_tune_available
+ from .integrations import is_sigopt_available as is_sigopt_available
+ from .integrations import is_swanlab_available as is_swanlab_available
+ from .integrations import is_tensorboard_available as is_tensorboard_available
+ from .integrations import is_trackio_available as is_trackio_available
+ from .integrations import is_wandb_available as is_wandb_available
+ from .integrations.executorch import TorchExportableModuleWithStaticCache as TorchExportableModuleWithStaticCache
+ from .integrations.executorch import convert_and_export_with_cache as convert_and_export_with_cache
+ from .keras_callbacks import KerasMetricCallback as KerasMetricCallback
+ from .keras_callbacks import PushToHubCallback as PushToHubCallback
+ from .masking_utils import AttentionMaskInterface as AttentionMaskInterface
+ from .model_debugging_utils import model_addition_debugger_context as model_addition_debugger_context
+
+ # Model Cards
+ from .modelcard import ModelCard as ModelCard
+ from .modeling_flax_utils import FlaxPreTrainedModel as FlaxPreTrainedModel
+ from .modeling_layers import GradientCheckpointingLayer as GradientCheckpointingLayer
+ from .modeling_rope_utils import ROPE_INIT_FUNCTIONS as ROPE_INIT_FUNCTIONS
+ from .modeling_rope_utils import dynamic_rope_update as dynamic_rope_update
+
+ # TF 2.0 <=> PyTorch conversion utilities
+ from .modeling_tf_pytorch_utils import (
+ convert_tf_weight_name_to_pt_weight_name as convert_tf_weight_name_to_pt_weight_name,
+ )
+ from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model as load_pytorch_checkpoint_in_tf2_model
+ from .modeling_tf_pytorch_utils import load_pytorch_model_in_tf2_model as load_pytorch_model_in_tf2_model
+ from .modeling_tf_pytorch_utils import load_pytorch_weights_in_tf2_model as load_pytorch_weights_in_tf2_model
+ from .modeling_tf_pytorch_utils import load_tf2_checkpoint_in_pytorch_model as load_tf2_checkpoint_in_pytorch_model
+ from .modeling_tf_pytorch_utils import load_tf2_model_in_pytorch_model as load_tf2_model_in_pytorch_model
+ from .modeling_tf_pytorch_utils import load_tf2_weights_in_pytorch_model as load_tf2_weights_in_pytorch_model
+ from .modeling_tf_utils import TFPreTrainedModel as TFPreTrainedModel
+ from .modeling_tf_utils import TFSequenceSummary as TFSequenceSummary
+ from .modeling_tf_utils import TFSharedEmbeddings as TFSharedEmbeddings
+ from .modeling_tf_utils import shape_list as shape_list
+ from .modeling_utils import AttentionInterface as AttentionInterface
+ from .modeling_utils import PreTrainedModel as PreTrainedModel
+ from .models import *
+ from .models.mamba.modeling_mamba import MambaCache as MambaCache
+ from .models.timm_wrapper import TimmWrapperImageProcessor as TimmWrapperImageProcessor
+
+ # Optimization
+ from .optimization import Adafactor as Adafactor
+ from .optimization import get_constant_schedule as get_constant_schedule
+ from .optimization import get_constant_schedule_with_warmup as get_constant_schedule_with_warmup
+ from .optimization import get_cosine_schedule_with_warmup as get_cosine_schedule_with_warmup
+ from .optimization import (
+ get_cosine_with_hard_restarts_schedule_with_warmup as get_cosine_with_hard_restarts_schedule_with_warmup,
+ )
+ from .optimization import (
+ get_cosine_with_min_lr_schedule_with_warmup as get_cosine_with_min_lr_schedule_with_warmup,
+ )
+ from .optimization import (
+ get_cosine_with_min_lr_schedule_with_warmup_lr_rate as get_cosine_with_min_lr_schedule_with_warmup_lr_rate,
+ )
+ from .optimization import get_inverse_sqrt_schedule as get_inverse_sqrt_schedule
+ from .optimization import get_linear_schedule_with_warmup as get_linear_schedule_with_warmup
+ from .optimization import get_polynomial_decay_schedule_with_warmup as get_polynomial_decay_schedule_with_warmup
+ from .optimization import get_scheduler as get_scheduler
+ from .optimization import get_wsd_schedule as get_wsd_schedule
+
+ # Optimization
+ from .optimization_tf import AdamWeightDecay as AdamWeightDecay
+ from .optimization_tf import GradientAccumulator as GradientAccumulator
+ from .optimization_tf import WarmUp as WarmUp
+ from .optimization_tf import create_optimizer as create_optimizer
+
+ # Pipelines
+ from .pipelines import AudioClassificationPipeline as AudioClassificationPipeline
+ from .pipelines import AutomaticSpeechRecognitionPipeline as AutomaticSpeechRecognitionPipeline
+ from .pipelines import CsvPipelineDataFormat as CsvPipelineDataFormat
+ from .pipelines import DepthEstimationPipeline as DepthEstimationPipeline
+ from .pipelines import DocumentQuestionAnsweringPipeline as DocumentQuestionAnsweringPipeline
+ from .pipelines import FeatureExtractionPipeline as FeatureExtractionPipeline
+ from .pipelines import FillMaskPipeline as FillMaskPipeline
+ from .pipelines import ImageClassificationPipeline as ImageClassificationPipeline
+ from .pipelines import ImageFeatureExtractionPipeline as ImageFeatureExtractionPipeline
+ from .pipelines import ImageSegmentationPipeline as ImageSegmentationPipeline
+ from .pipelines import ImageTextToTextPipeline as ImageTextToTextPipeline
+ from .pipelines import ImageToImagePipeline as ImageToImagePipeline
+ from .pipelines import ImageToTextPipeline as ImageToTextPipeline
+ from .pipelines import JsonPipelineDataFormat as JsonPipelineDataFormat
+ from .pipelines import KeypointMatchingPipeline as KeypointMatchingPipeline
+ from .pipelines import MaskGenerationPipeline as MaskGenerationPipeline
+ from .pipelines import NerPipeline as NerPipeline
+ from .pipelines import ObjectDetectionPipeline as ObjectDetectionPipeline
+ from .pipelines import PipedPipelineDataFormat as PipedPipelineDataFormat
+ from .pipelines import Pipeline as Pipeline
+ from .pipelines import PipelineDataFormat as PipelineDataFormat
+ from .pipelines import QuestionAnsweringPipeline as QuestionAnsweringPipeline
+ from .pipelines import SummarizationPipeline as SummarizationPipeline
+ from .pipelines import TableQuestionAnsweringPipeline as TableQuestionAnsweringPipeline
+ from .pipelines import Text2TextGenerationPipeline as Text2TextGenerationPipeline
+ from .pipelines import TextClassificationPipeline as TextClassificationPipeline
+ from .pipelines import TextGenerationPipeline as TextGenerationPipeline
+ from .pipelines import TextToAudioPipeline as TextToAudioPipeline
+ from .pipelines import TokenClassificationPipeline as TokenClassificationPipeline
+ from .pipelines import TranslationPipeline as TranslationPipeline
+ from .pipelines import VideoClassificationPipeline as VideoClassificationPipeline
+ from .pipelines import VisualQuestionAnsweringPipeline as VisualQuestionAnsweringPipeline
+ from .pipelines import ZeroShotAudioClassificationPipeline as ZeroShotAudioClassificationPipeline
+ from .pipelines import ZeroShotClassificationPipeline as ZeroShotClassificationPipeline
+ from .pipelines import ZeroShotImageClassificationPipeline as ZeroShotImageClassificationPipeline
+ from .pipelines import ZeroShotObjectDetectionPipeline as ZeroShotObjectDetectionPipeline
+ from .pipelines import pipeline as pipeline
+ from .processing_utils import ProcessorMixin as ProcessorMixin
+ from .pytorch_utils import Conv1D as Conv1D
+ from .pytorch_utils import apply_chunking_to_forward as apply_chunking_to_forward
+ from .pytorch_utils import prune_layer as prune_layer
+
+ # Tokenization
+ from .tokenization_utils import PreTrainedTokenizer as PreTrainedTokenizer
+ from .tokenization_utils_base import AddedToken as AddedToken
+ from .tokenization_utils_base import BatchEncoding as BatchEncoding
+ from .tokenization_utils_base import CharSpan as CharSpan
+ from .tokenization_utils_base import PreTrainedTokenizerBase as PreTrainedTokenizerBase
+ from .tokenization_utils_base import SpecialTokensMixin as SpecialTokensMixin
+ from .tokenization_utils_base import TokenSpan as TokenSpan
+ from .tokenization_utils_fast import PreTrainedTokenizerFast as PreTrainedTokenizerFast
+
+ # Trainer
+ from .trainer import Trainer as Trainer
+
+ # Trainer
+ from .trainer_callback import DefaultFlowCallback as DefaultFlowCallback
+ from .trainer_callback import EarlyStoppingCallback as EarlyStoppingCallback
+ from .trainer_callback import PrinterCallback as PrinterCallback
+ from .trainer_callback import ProgressCallback as ProgressCallback
+ from .trainer_callback import TrainerCallback as TrainerCallback
+ from .trainer_callback import TrainerControl as TrainerControl
+ from .trainer_callback import TrainerState as TrainerState
+ from .trainer_pt_utils import torch_distributed_zero_first as torch_distributed_zero_first
+ from .trainer_seq2seq import Seq2SeqTrainer as Seq2SeqTrainer
+ from .trainer_utils import EvalPrediction as EvalPrediction
+ from .trainer_utils import IntervalStrategy as IntervalStrategy
+ from .trainer_utils import SchedulerType as SchedulerType
+ from .trainer_utils import enable_full_determinism as enable_full_determinism
+ from .trainer_utils import set_seed as set_seed
+ from .training_args import TrainingArguments as TrainingArguments
+ from .training_args_seq2seq import Seq2SeqTrainingArguments as Seq2SeqTrainingArguments
+ from .training_args_tf import TFTrainingArguments as TFTrainingArguments
+
+ # Files and general utilities
+ from .utils import CONFIG_NAME as CONFIG_NAME
+ from .utils import MODEL_CARD_NAME as MODEL_CARD_NAME
+ from .utils import PYTORCH_PRETRAINED_BERT_CACHE as PYTORCH_PRETRAINED_BERT_CACHE
+ from .utils import PYTORCH_TRANSFORMERS_CACHE as PYTORCH_TRANSFORMERS_CACHE
+ from .utils import SPIECE_UNDERLINE as SPIECE_UNDERLINE
+ from .utils import TF2_WEIGHTS_NAME as TF2_WEIGHTS_NAME
+ from .utils import TF_WEIGHTS_NAME as TF_WEIGHTS_NAME
+ from .utils import TRANSFORMERS_CACHE as TRANSFORMERS_CACHE
+ from .utils import WEIGHTS_NAME as WEIGHTS_NAME
+ from .utils import TensorType as TensorType
+ from .utils import add_end_docstrings as add_end_docstrings
+ from .utils import add_start_docstrings as add_start_docstrings
+ from .utils import is_apex_available as is_apex_available
+ from .utils import is_av_available as is_av_available
+ from .utils import is_datasets_available as is_datasets_available
+ from .utils import is_faiss_available as is_faiss_available
+ from .utils import is_matplotlib_available as is_matplotlib_available
+ from .utils import is_phonemizer_available as is_phonemizer_available
+ from .utils import is_psutil_available as is_psutil_available
+ from .utils import is_py3nvml_available as is_py3nvml_available
+ from .utils import is_pyctcdecode_available as is_pyctcdecode_available
+ from .utils import is_sacremoses_available as is_sacremoses_available
+ from .utils import is_safetensors_available as is_safetensors_available
+ from .utils import is_sklearn_available as is_sklearn_available
+ from .utils import is_torch_hpu_available as is_torch_hpu_available
+ from .utils import is_torch_mlu_available as is_torch_mlu_available
+ from .utils import is_torch_musa_available as is_torch_musa_available
+ from .utils import is_torch_neuroncore_available as is_torch_neuroncore_available
+ from .utils import is_torch_npu_available as is_torch_npu_available
+ from .utils import is_torch_xla_available as is_torch_xla_available
+ from .utils import is_torch_xpu_available as is_torch_xpu_available
+
+ # bitsandbytes config
+ from .utils.quantization_config import AqlmConfig as AqlmConfig
+ from .utils.quantization_config import AutoRoundConfig as AutoRoundConfig
+ from .utils.quantization_config import AwqConfig as AwqConfig
+ from .utils.quantization_config import BitNetQuantConfig as BitNetQuantConfig
+ from .utils.quantization_config import BitsAndBytesConfig as BitsAndBytesConfig
+ from .utils.quantization_config import CompressedTensorsConfig as CompressedTensorsConfig
+ from .utils.quantization_config import EetqConfig as EetqConfig
+ from .utils.quantization_config import FbgemmFp8Config as FbgemmFp8Config
+ from .utils.quantization_config import FineGrainedFP8Config as FineGrainedFP8Config
+ from .utils.quantization_config import FPQuantConfig as FPQuantConfig
+ from .utils.quantization_config import GPTQConfig as GPTQConfig
+ from .utils.quantization_config import HiggsConfig as HiggsConfig
+ from .utils.quantization_config import HqqConfig as HqqConfig
+ from .utils.quantization_config import QuantoConfig as QuantoConfig
+ from .utils.quantization_config import QuarkConfig as QuarkConfig
+ from .utils.quantization_config import SpQRConfig as SpQRConfig
+ from .utils.quantization_config import TorchAoConfig as TorchAoConfig
+ from .utils.quantization_config import VptqConfig as VptqConfig
+ from .video_processing_utils import BaseVideoProcessor as BaseVideoProcessor
+
+else:
+ import sys
+
+ _import_structure = {k: set(v) for k, v in _import_structure.items()}
+
+ import_structure = define_import_structure(Path(__file__).parent / "models", prefix="models")
+ import_structure[frozenset({})].update(_import_structure)
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ import_structure,
+ module_spec=__spec__,
+ extra_objects={"__version__": __version__},
+ )
+
+
+if not is_tf_available() and not is_torch_available() and not is_flax_available():
+ logger.warning_advice(
+ "None of PyTorch, TensorFlow >= 2.0, or Flax have been found. "
+ "Models won't be available and only tokenizers, configuration "
+ "and file/data utilities can be used."
+ )
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/activations.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/activations.py
new file mode 100644
index 0000000000000000000000000000000000000000..7642e8aa238a6df701da95f0a58bef3156baf0e0
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/activations.py
@@ -0,0 +1,356 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+import math
+from collections import OrderedDict
+
+import torch
+from torch import Tensor, nn
+
+from .integrations.hub_kernels import use_kernel_forward_from_hub
+from .utils import logging
+from .utils.import_utils import is_torchdynamo_compiling
+
+
+logger = logging.get_logger(__name__)
+
+
+@use_kernel_forward_from_hub("GeluTanh")
+class GELUTanh(nn.Module):
+ """
+ A fast C implementation of the tanh approximation of the GeLU activation function. See
+ https://huggingface.co/papers/1606.08415.
+
+ This implementation is equivalent to NewGELU and FastGELU but much faster. However, it is not an exact numerical
+ match due to rounding errors.
+ """
+
+ def __init__(self, use_gelu_tanh_python: bool = False):
+ super().__init__()
+ if use_gelu_tanh_python:
+ self.act = self._gelu_tanh_python
+ else:
+ self.act = functools.partial(nn.functional.gelu, approximate="tanh")
+
+ def _gelu_tanh_python(self, input: Tensor) -> Tensor:
+ return input * 0.5 * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
+
+ def forward(self, input: Tensor) -> Tensor:
+ return self.act(input)
+
+
+@use_kernel_forward_from_hub("NewGELU")
+class NewGELUActivation(nn.Module):
+ """
+ Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
+ the Gaussian Error Linear Units paper: https://huggingface.co/papers/1606.08415
+ """
+
+ def forward(self, input: Tensor) -> Tensor:
+ return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
+
+
+@use_kernel_forward_from_hub("GeLU")
+class GELUActivation(nn.Module):
+ """
+ Original Implementation of the GELU activation function in Google BERT repo when initially created. For
+ information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
+ torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional
+ Also see the Gaussian Error Linear Units paper: https://huggingface.co/papers/1606.08415
+ """
+
+ def __init__(self, use_gelu_python: bool = False):
+ super().__init__()
+ if use_gelu_python:
+ self.act = self._gelu_python
+ else:
+ self.act = nn.functional.gelu
+
+ def _gelu_python(self, input: Tensor) -> Tensor:
+ return input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0)))
+
+ def forward(self, input: Tensor) -> Tensor:
+ return self.act(input)
+
+
+@use_kernel_forward_from_hub("SiLU")
+class SiLUActivation(nn.Module):
+ """
+ See Gaussian Error Linear Units (Hendrycks et al., https://arxiv.org/abs/1606.08415) where the SiLU (Sigmoid Linear
+ Unit) was originally introduced and coined, and see Sigmoid-Weighted Linear Units for Neural Network Function
+ Approximation in Reinforcement Learning (Elfwing et al., https://arxiv.org/abs/1702.03118) and Swish: a Self-Gated
+ Activation Function (Ramachandran et al., https://arxiv.org/abs/1710.05941v1) where the SiLU was experimented with
+ later.
+ """
+
+ def forward(self, input: Tensor) -> Tensor:
+ return nn.functional.silu(input)
+
+
+@use_kernel_forward_from_hub("FastGELU")
+class FastGELUActivation(nn.Module):
+ """
+ Applies GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs
+ """
+
+ def forward(self, input: Tensor) -> Tensor:
+ return 0.5 * input * (1.0 + torch.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input)))
+
+
+@use_kernel_forward_from_hub("QuickGELU")
+class QuickGELUActivation(nn.Module):
+ """
+ Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
+ """
+
+ def forward(self, input: Tensor) -> Tensor:
+ return input * torch.sigmoid(1.702 * input)
+
+
+class ClippedGELUActivation(nn.Module):
+ """
+ Clip the range of possible GeLU outputs between [min, max]. This is especially useful for quantization purpose, as
+ it allows mapping negatives values in the GeLU spectrum. For more information on this trick, please refer to
+ https://huggingface.co/papers/2004.09602.
+
+ Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
+ initially created.
+
+ For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 +
+ torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))). See https://huggingface.co/papers/1606.08415
+ """
+
+ def __init__(self, min: float, max: float):
+ if min > max:
+ raise ValueError(f"min should be < max (got min: {min}, max: {max})")
+
+ super().__init__()
+ self.min = min
+ self.max = max
+
+ def forward(self, x: Tensor) -> Tensor:
+ return torch.clip(gelu(x), self.min, self.max)
+
+
+class AccurateGELUActivation(nn.Module):
+ """
+ Applies GELU approximation that is faster than default and more accurate than QuickGELU. See:
+ https://github.com/hendrycks/GELUs
+
+ Implemented along with MEGA (Moving Average Equipped Gated Attention)
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.precomputed_constant = math.sqrt(2 / math.pi)
+
+ def forward(self, input: Tensor) -> Tensor:
+ return 0.5 * input * (1 + torch.tanh(self.precomputed_constant * (input + 0.044715 * torch.pow(input, 3))))
+
+
+class MishActivation(nn.Module):
+ """
+ See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://huggingface.co/papers/1908.08681). Also
+ visit the official repository for the paper: https://github.com/digantamisra98/Mish
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.act = nn.functional.mish
+
+ def _mish_python(self, input: Tensor) -> Tensor:
+ return input * torch.tanh(nn.functional.softplus(input))
+
+ def forward(self, input: Tensor) -> Tensor:
+ return self.act(input)
+
+
+class LinearActivation(nn.Module):
+ """
+ Applies the linear activation function, i.e. forwarding input directly to output.
+ """
+
+ def forward(self, input: Tensor) -> Tensor:
+ return input
+
+
+class LaplaceActivation(nn.Module):
+ """
+ Applies elementwise activation based on Laplace function, introduced in MEGA as an attention activation. See
+ https://huggingface.co/papers/2209.10655
+
+ Inspired by squared relu, but with bounded range and gradient for better stability
+ """
+
+ def forward(self, input, mu=0.707107, sigma=0.282095):
+ input = (input - mu).div(sigma * math.sqrt(2.0))
+ return 0.5 * (1.0 + torch.erf(input))
+
+
+class ReLUSquaredActivation(nn.Module):
+ """
+ Applies the relu^2 activation introduced in https://huggingface.co/papers/2109.08668v2
+ """
+
+ def forward(self, input):
+ relu_applied = nn.functional.relu(input)
+ squared = torch.square(relu_applied)
+ return squared
+
+
+class ClassInstantier(OrderedDict):
+ def __getitem__(self, key):
+ content = super().__getitem__(key)
+ cls, kwargs = content if isinstance(content, tuple) else (content, {})
+ return cls(**kwargs)
+
+
+class XIELUActivation(nn.Module):
+ """
+ Applies the xIELU activation function introduced in https://arxiv.org/abs/2411.13010
+
+ If the user has installed the nickjbrowning/XIELU wheel, we import xIELU CUDA
+ Otherwise, we emit a single warning and use xIELU Python
+ """
+
+ def __init__(
+ self,
+ alpha_p_init=0.8,
+ alpha_n_init=0.8,
+ beta=0.5,
+ eps=-1e-6,
+ dtype=torch.bfloat16,
+ with_vector_loads=False,
+ ):
+ super().__init__()
+ self.alpha_p = nn.Parameter(torch.log(torch.expm1(torch.tensor(alpha_p_init, dtype=dtype))).unsqueeze(0))
+ self.alpha_n = nn.Parameter(
+ torch.log(torch.expm1(torch.tensor(alpha_n_init - beta, dtype=dtype))).unsqueeze(0)
+ )
+ self.register_buffer("beta", torch.tensor(beta, dtype=dtype))
+ self.register_buffer("eps", torch.tensor(eps, dtype=dtype))
+ self.with_vector_loads = with_vector_loads
+ # Temporary until xIELU CUDA fully implemented
+ self._beta_scalar = float(self.beta.detach().cpu().float().item())
+ self._eps_scalar = float(self.eps.detach().cpu().float().item())
+
+ self._xielu_cuda_obj = None
+ try:
+ import xielu.ops # noqa: F401
+
+ self._xielu_cuda_obj = torch.classes.xielu.XIELU()
+ msg = "Using experimental xIELU CUDA."
+ try:
+ from torch._dynamo import allow_in_graph
+
+ self._xielu_cuda_fn = allow_in_graph(self._xielu_cuda)
+ msg += " Enabled torch._dynamo for xIELU CUDA."
+ except Exception as err:
+ msg += f" Could not enable torch._dynamo for xIELU ({err}) - this may result in slower performance."
+ self._xielu_cuda_fn = self._xielu_cuda
+ logger.warning_once(msg)
+ except Exception as err:
+ logger.warning_once(
+ "CUDA-fused xIELU not available (%s) – falling back to a Python version.\n"
+ "For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`",
+ str(err),
+ )
+
+ def _xielu_python(self, x: Tensor) -> Tensor:
+ alpha_p = nn.functional.softplus(self.alpha_p)
+ alpha_n = self.beta + nn.functional.softplus(self.alpha_n)
+ return torch.where(
+ x > 0,
+ alpha_p * x * x + self.beta * x,
+ (torch.expm1(torch.min(x, self.eps)) - x) * alpha_n + self.beta * x,
+ )
+
+ def _xielu_cuda(self, x: Tensor) -> Tensor:
+ """Firewall function to prevent torch.compile from seeing .item() calls"""
+ original_shape = x.shape
+ # CUDA kernel expects 3D tensors, reshape if needed
+ while x.dim() < 3:
+ x = x.unsqueeze(0)
+ if x.dim() > 3:
+ x = x.view(-1, 1, x.size(-1))
+ if original_shape != x.shape:
+ logger.warning_once(
+ "Warning: xIELU input tensor expects 3 dimensions but got (shape: %s). Reshaping to (shape: %s).",
+ original_shape,
+ x.shape,
+ )
+ result = self._xielu_cuda_obj.forward(
+ x,
+ self.alpha_p.to(x.dtype),
+ self.alpha_n.to(x.dtype),
+ # Temporary until xIELU CUDA fully implemented -> self.{beta,eps}.item()
+ self._beta_scalar,
+ self._eps_scalar,
+ self.with_vector_loads,
+ )
+ return result.view(original_shape)
+
+ def forward(self, input: Tensor) -> Tensor:
+ if self._xielu_cuda_obj is not None and input.is_cuda:
+ if not is_torchdynamo_compiling():
+ return self._xielu_cuda_fn(input)
+ else:
+ logger.warning_once("torch._dynamo is compiling, using Python version of xIELU.")
+ return self._xielu_python(input)
+
+
+ACT2CLS = {
+ "gelu": GELUActivation,
+ "gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}),
+ "gelu_fast": FastGELUActivation,
+ "gelu_new": NewGELUActivation,
+ "gelu_python": (GELUActivation, {"use_gelu_python": True}),
+ "gelu_pytorch_tanh": GELUTanh,
+ "gelu_python_tanh": (GELUTanh, {"use_gelu_tanh_python": True}),
+ "gelu_accurate": AccurateGELUActivation,
+ "laplace": LaplaceActivation,
+ "leaky_relu": nn.LeakyReLU,
+ "linear": LinearActivation,
+ "mish": MishActivation,
+ "quick_gelu": QuickGELUActivation,
+ "relu": nn.ReLU,
+ "relu2": ReLUSquaredActivation,
+ "relu6": nn.ReLU6,
+ "sigmoid": nn.Sigmoid,
+ "silu": SiLUActivation,
+ "swish": nn.SiLU,
+ "tanh": nn.Tanh,
+ "prelu": nn.PReLU,
+ "xielu": XIELUActivation,
+}
+ACT2FN = ClassInstantier(ACT2CLS)
+
+
+def get_activation(activation_string):
+ if activation_string in ACT2FN:
+ return ACT2FN[activation_string]
+ else:
+ raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
+
+
+# For backwards compatibility with: from activations import gelu_python
+gelu_python = get_activation("gelu_python")
+gelu_new = get_activation("gelu_new")
+gelu = get_activation("gelu")
+gelu_fast = get_activation("gelu_fast")
+quick_gelu = get_activation("quick_gelu")
+silu = get_activation("silu")
+mish = get_activation("mish")
+linear_act = get_activation("linear")
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/activations_tf.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/activations_tf.py
new file mode 100644
index 0000000000000000000000000000000000000000..8dccf6c4f46b8fe1f98d7e57bd8611f660ed19f4
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/activations_tf.py
@@ -0,0 +1,147 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+
+import tensorflow as tf
+from packaging.version import parse
+
+
+try:
+ import tf_keras as keras
+except (ModuleNotFoundError, ImportError):
+ import keras
+
+ if parse(keras.__version__).major > 2:
+ raise ValueError(
+ "Your currently installed version of Keras is Keras 3, but this is not yet supported in "
+ "Transformers. Please install the backwards-compatible tf-keras package with "
+ "`pip install tf-keras`."
+ )
+
+
+def _gelu(x):
+ """
+ Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
+ initially created. For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
+ 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) Also see
+ https://huggingface.co/papers/1606.08415
+ """
+ x = tf.convert_to_tensor(x)
+ cdf = 0.5 * (1.0 + tf.math.erf(x / tf.cast(tf.sqrt(2.0), x.dtype)))
+
+ return x * cdf
+
+
+def _gelu_new(x):
+ """
+ Gaussian Error Linear Unit. This is a smoother version of the GELU. Original paper: https://huggingface.co/papers/1606.0841
+
+ Args:
+ x: float Tensor to perform activation
+
+ Returns:
+ `x` with the GELU activation applied.
+ """
+ x = tf.convert_to_tensor(x)
+ pi = tf.cast(math.pi, x.dtype)
+ coeff = tf.cast(0.044715, x.dtype)
+ cdf = 0.5 * (1.0 + tf.tanh(tf.sqrt(2.0 / pi) * (x + coeff * tf.pow(x, 3))))
+
+ return x * cdf
+
+
+def mish(x):
+ x = tf.convert_to_tensor(x)
+
+ return x * tf.tanh(tf.math.softplus(x))
+
+
+def gelu_fast(x):
+ x = tf.convert_to_tensor(x)
+ coeff1 = tf.cast(0.044715, x.dtype)
+ coeff2 = tf.cast(0.7978845608, x.dtype)
+
+ return 0.5 * x * (1.0 + tf.tanh(x * coeff2 * (1.0 + coeff1 * x * x)))
+
+
+def quick_gelu(x):
+ x = tf.convert_to_tensor(x)
+ coeff = tf.cast(1.702, x.dtype)
+ return x * tf.math.sigmoid(coeff * x)
+
+
+def gelu_10(x):
+ """
+ Clip the range of possible GeLU outputs between [-10, 10]. This is especially useful for quantization purpose, as
+ it allows mapping 2 negatives values in the GeLU spectrum. For more information on this trick, please refer to
+ https://huggingface.co/papers/2004.09602
+
+ Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
+ initially created. For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
+ 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) Also see
+ https://huggingface.co/papers/1606.08415 :param x: :return:
+ """
+ return tf.clip_by_value(_gelu(x), -10, 10)
+
+
+def glu(x, axis=-1):
+ """
+ Gated Linear Unit. Implementation as defined in the original paper (see https://huggingface.co/papers/1612.08083), where
+ the input `x` is split in two halves across a dimension (`axis`), A and B, returning A * sigmoid(B).
+
+ Args:
+ `x`: float Tensor to perform activation
+ `axis`: dimension across which `x` be split in half
+
+ Returns:
+ `x` with the GLU activation applied (with its size halved across the dimension `axis`).
+ """
+ a, b = tf.split(x, 2, axis=axis)
+ return a * tf.math.sigmoid(b)
+
+
+if parse(tf.version.VERSION) >= parse("2.4"):
+
+ def approximate_gelu_wrap(x):
+ return keras.activations.gelu(x, approximate=True)
+
+ gelu = keras.activations.gelu
+ gelu_new = approximate_gelu_wrap
+else:
+ gelu = _gelu
+ gelu_new = _gelu_new
+
+
+ACT2FN = {
+ "gelu": gelu,
+ "gelu_10": gelu_10,
+ "gelu_fast": gelu_fast,
+ "gelu_new": gelu_new,
+ "glu": glu,
+ "mish": mish,
+ "quick_gelu": quick_gelu,
+ "relu": keras.activations.relu,
+ "sigmoid": keras.activations.sigmoid,
+ "silu": keras.activations.swish,
+ "swish": keras.activations.swish,
+ "tanh": keras.activations.tanh,
+}
+
+
+def get_tf_activation(activation_string):
+ if activation_string in ACT2FN:
+ return ACT2FN[activation_string]
+ else:
+ raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/audio_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/audio_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5de56618014ecdb1443df0968018ce25dc394fc0
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/audio_utils.py
@@ -0,0 +1,1224 @@
+# Copyright 2023 The HuggingFace Inc. team and the librosa & torchaudio authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Audio processing functions to extract features from audio waveforms. This code is pure numpy to support all frameworks
+and remove unnecessary dependencies.
+"""
+
+import base64
+import importlib
+import io
+import os
+import warnings
+from collections.abc import Sequence
+from io import BytesIO
+from typing import TYPE_CHECKING, Any, Optional, Union
+
+
+if TYPE_CHECKING:
+ import torch
+import numpy as np
+import requests
+from packaging import version
+
+from .utils import (
+ is_librosa_available,
+ is_numpy_array,
+ is_soundfile_available,
+ is_torch_tensor,
+ is_torchcodec_available,
+ requires_backends,
+)
+
+
+if is_soundfile_available():
+ import soundfile as sf
+
+if is_librosa_available():
+ import librosa
+
+ # TODO: @eustlb, we actually don't need librosa but soxr is installed with librosa
+ import soxr
+
+if is_torchcodec_available():
+ TORCHCODEC_VERSION = version.parse(importlib.metadata.version("torchcodec"))
+
+AudioInput = Union[np.ndarray, "torch.Tensor", Sequence[np.ndarray], Sequence["torch.Tensor"]]
+
+
+def load_audio(audio: Union[str, np.ndarray], sampling_rate=16000, timeout=None) -> np.ndarray:
+ """
+ Loads `audio` to an np.ndarray object.
+
+ Args:
+ audio (`str` or `np.ndarray`):
+ The audio to be loaded to the numpy array format.
+ sampling_rate (`int`, *optional*, defaults to 16000):
+ The sampling rate to be used when loading the audio. It should be same as the
+ sampling rate the model you will be using further was trained with.
+ timeout (`float`, *optional*):
+ The timeout value in seconds for the URL request.
+
+ Returns:
+ `np.ndarray`: A numpy array representing the audio.
+ """
+ if isinstance(audio, str):
+ # Try to load with `torchcodec` but do not enforce users to install it. If not found
+ # fallback to `librosa`. If using an audio-only model, most probably `torchcodec` won't be
+ # needed. Do not raise any errors if not installed or versions do not match
+ if is_torchcodec_available() and TORCHCODEC_VERSION >= version.parse("0.3.0"):
+ audio = load_audio_torchcodec(audio, sampling_rate=sampling_rate)
+ else:
+ audio = load_audio_librosa(audio, sampling_rate=sampling_rate, timeout=timeout)
+ elif not isinstance(audio, np.ndarray):
+ raise TypeError(
+ "Incorrect format used for `audio`. Should be an url linking to an audio, a local path, or numpy array."
+ )
+ return audio
+
+
+def load_audio_torchcodec(audio: Union[str, np.ndarray], sampling_rate=16000) -> np.ndarray:
+ """
+ Loads `audio` to an np.ndarray object using `torchcodec`.
+
+ Args:
+ audio (`str` or `np.ndarray`):
+ The audio to be loaded to the numpy array format.
+ sampling_rate (`int`, *optional*, defaults to 16000):
+ The sampling rate to be used when loading the audio. It should be same as the
+ sampling rate the model you will be using further was trained with.
+
+ Returns:
+ `np.ndarray`: A numpy array representing the audio.
+ """
+ # Lazy import so that issues in torchcodec compatibility don't crash the whole library
+ requires_backends(load_audio_torchcodec, ["torchcodec"])
+ from torchcodec.decoders import AudioDecoder
+
+ # Set `num_channels` to `1` which is what most models expects and the default in librosa
+ decoder = AudioDecoder(audio, sample_rate=sampling_rate, num_channels=1)
+ audio = decoder.get_all_samples().data[0].numpy() # NOTE: feature extractors don't accept torch tensors
+ return audio
+
+
+def load_audio_librosa(audio: Union[str, np.ndarray], sampling_rate=16000, timeout=None) -> np.ndarray:
+ """
+ Loads `audio` to an np.ndarray object using `librosa`.
+
+ Args:
+ audio (`str` or `np.ndarray`):
+ The audio to be loaded to the numpy array format.
+ sampling_rate (`int`, *optional*, defaults to 16000):
+ The sampling rate to be used when loading the audio. It should be same as the
+ sampling rate the model you will be using further was trained with.
+ timeout (`float`, *optional*):
+ The timeout value in seconds for the URL request.
+
+ Returns:
+ `np.ndarray`: A numpy array representing the audio.
+ """
+ requires_backends(load_audio_librosa, ["librosa"])
+
+ # Load audio from URL (e.g https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/translate_to_chinese.wav)
+ if audio.startswith("http://") or audio.startswith("https://"):
+ audio = librosa.load(BytesIO(requests.get(audio, timeout=timeout).content), sr=sampling_rate)[0]
+ elif os.path.isfile(audio):
+ audio = librosa.load(audio, sr=sampling_rate)[0]
+ return audio
+
+
+def load_audio_as(
+ audio: str,
+ return_format: str,
+ timeout: Optional[int] = None,
+ force_mono: bool = False,
+ sampling_rate: Optional[int] = None,
+) -> Union[str, dict[str, Any], io.BytesIO, None]:
+ """
+ Load audio from either a local file path or URL and return in specified format.
+
+ Args:
+ audio (`str`): Either a local file path or a URL to an audio file
+ return_format (`str`): Format to return the audio in:
+ - "base64": Base64 encoded string
+ - "dict": Dictionary with data and format
+ - "buffer": BytesIO object
+ timeout (`int`, *optional*): Timeout for URL requests in seconds
+ force_mono (`bool`): Whether to convert stereo audio to mono
+ sampling_rate (`int`, *optional*): If provided, the audio will be resampled to the specified sampling rate.
+
+ Returns:
+ `Union[str, Dict[str, Any], io.BytesIO, None]`:
+ - `str`: Base64 encoded audio data (if return_format="base64")
+ - `dict`: Dictionary with 'data' (base64 encoded audio data) and 'format' keys (if return_format="dict")
+ - `io.BytesIO`: BytesIO object containing audio data (if return_format="buffer")
+ """
+ # TODO: @eustlb, we actually don't need librosa but soxr is installed with librosa
+ requires_backends(load_audio_as, ["librosa"])
+
+ if return_format not in ["base64", "dict", "buffer"]:
+ raise ValueError(f"Invalid return_format: {return_format}. Must be 'base64', 'dict', or 'buffer'")
+
+ try:
+ # Load audio bytes from URL or file
+ audio_bytes = None
+ if audio.startswith(("http://", "https://")):
+ response = requests.get(audio, timeout=timeout)
+ response.raise_for_status()
+ audio_bytes = response.content
+ elif os.path.isfile(audio):
+ with open(audio, "rb") as audio_file:
+ audio_bytes = audio_file.read()
+ else:
+ raise ValueError(f"File not found: {audio}")
+
+ # Process audio data
+ with io.BytesIO(audio_bytes) as audio_file:
+ with sf.SoundFile(audio_file) as f:
+ audio_array = f.read(dtype="float32")
+ original_sr = f.samplerate
+ audio_format = f.format
+ if sampling_rate is not None and sampling_rate != original_sr:
+ # Resample audio to target sampling rate
+ audio_array = soxr.resample(audio_array, original_sr, sampling_rate, quality="HQ")
+ else:
+ sampling_rate = original_sr
+
+ # Convert to mono if needed
+ if force_mono and audio_array.ndim != 1:
+ audio_array = audio_array.mean(axis=1)
+
+ buffer = io.BytesIO()
+ sf.write(buffer, audio_array, sampling_rate, format=audio_format.upper())
+ buffer.seek(0)
+
+ if return_format == "buffer":
+ return buffer
+ elif return_format == "base64":
+ return base64.b64encode(buffer.read()).decode("utf-8")
+ elif return_format == "dict":
+ return {
+ "data": base64.b64encode(buffer.read()).decode("utf-8"),
+ "format": audio_format.lower(),
+ }
+
+ except Exception as e:
+ raise ValueError(f"Error loading audio: {e}")
+
+
+def is_valid_audio(audio):
+ return is_numpy_array(audio) or is_torch_tensor(audio)
+
+
+def is_valid_list_of_audio(audio):
+ return audio and all(is_valid_audio(audio_i) for audio_i in audio)
+
+
+def make_list_of_audio(
+ audio: Union[list[AudioInput], AudioInput],
+) -> AudioInput:
+ """
+ Ensure that the output is a list of audio.
+ Args:
+ audio (`Union[list[AudioInput], AudioInput]`):
+ The input audio.
+ Returns:
+ list: A list of audio.
+ """
+ # If it's a list of audios, it's already in the right format
+ if isinstance(audio, (list, tuple)) and is_valid_list_of_audio(audio):
+ return audio
+
+ # If it's a single audio, convert it to a list of
+ if is_valid_audio(audio):
+ return [audio]
+
+ raise ValueError("Invalid input type. Must be a single audio or a list of audio")
+
+
+def hertz_to_mel(freq: Union[float, np.ndarray], mel_scale: str = "htk") -> Union[float, np.ndarray]:
+ """
+ Convert frequency from hertz to mels.
+
+ Args:
+ freq (`float` or `np.ndarray`):
+ The frequency, or multiple frequencies, in hertz (Hz).
+ mel_scale (`str`, *optional*, defaults to `"htk"`):
+ The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`.
+
+ Returns:
+ `float` or `np.ndarray`: The frequencies on the mel scale.
+ """
+
+ if mel_scale not in ["slaney", "htk", "kaldi"]:
+ raise ValueError('mel_scale should be one of "htk", "slaney" or "kaldi".')
+
+ if mel_scale == "htk":
+ return 2595.0 * np.log10(1.0 + (freq / 700.0))
+ elif mel_scale == "kaldi":
+ return 1127.0 * np.log(1.0 + (freq / 700.0))
+
+ min_log_hertz = 1000.0
+ min_log_mel = 15.0
+ logstep = 27.0 / np.log(6.4)
+ mels = 3.0 * freq / 200.0
+
+ if isinstance(freq, np.ndarray):
+ log_region = freq >= min_log_hertz
+ mels[log_region] = min_log_mel + np.log(freq[log_region] / min_log_hertz) * logstep
+ elif freq >= min_log_hertz:
+ mels = min_log_mel + np.log(freq / min_log_hertz) * logstep
+
+ return mels
+
+
+def mel_to_hertz(mels: Union[float, np.ndarray], mel_scale: str = "htk") -> Union[float, np.ndarray]:
+ """
+ Convert frequency from mels to hertz.
+
+ Args:
+ mels (`float` or `np.ndarray`):
+ The frequency, or multiple frequencies, in mels.
+ mel_scale (`str`, *optional*, `"htk"`):
+ The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`.
+
+ Returns:
+ `float` or `np.ndarray`: The frequencies in hertz.
+ """
+
+ if mel_scale not in ["slaney", "htk", "kaldi"]:
+ raise ValueError('mel_scale should be one of "htk", "slaney" or "kaldi".')
+
+ if mel_scale == "htk":
+ return 700.0 * (np.power(10, mels / 2595.0) - 1.0)
+ elif mel_scale == "kaldi":
+ return 700.0 * (np.exp(mels / 1127.0) - 1.0)
+
+ min_log_hertz = 1000.0
+ min_log_mel = 15.0
+ logstep = np.log(6.4) / 27.0
+ freq = 200.0 * mels / 3.0
+
+ if isinstance(mels, np.ndarray):
+ log_region = mels >= min_log_mel
+ freq[log_region] = min_log_hertz * np.exp(logstep * (mels[log_region] - min_log_mel))
+ elif mels >= min_log_mel:
+ freq = min_log_hertz * np.exp(logstep * (mels - min_log_mel))
+
+ return freq
+
+
+def hertz_to_octave(freq: Union[float, np.ndarray], tuning: float = 0.0, bins_per_octave: int = 12):
+ """
+ Convert frequency from hertz to fractional octave numbers.
+ Adapted from *librosa*.
+
+ Args:
+ freq (`float` or `np.ndarray`):
+ The frequency, or multiple frequencies, in hertz (Hz).
+ tuning (`float`, defaults to `0.`):
+ Tuning deviation from the Stuttgart pitch (A440) in (fractional) bins per octave.
+ bins_per_octave (`int`, defaults to `12`):
+ Number of bins per octave.
+
+ Returns:
+ `float` or `np.ndarray`: The frequencies on the octave scale.
+ """
+ stuttgart_pitch = 440.0 * 2.0 ** (tuning / bins_per_octave)
+ octave = np.log2(freq / (float(stuttgart_pitch) / 16))
+ return octave
+
+
+def _create_triangular_filter_bank(fft_freqs: np.ndarray, filter_freqs: np.ndarray) -> np.ndarray:
+ """
+ Creates a triangular filter bank.
+
+ Adapted from *torchaudio* and *librosa*.
+
+ Args:
+ fft_freqs (`np.ndarray` of shape `(num_frequency_bins,)`):
+ Discrete frequencies of the FFT bins in Hz.
+ filter_freqs (`np.ndarray` of shape `(num_mel_filters,)`):
+ Center frequencies of the triangular filters to create, in Hz.
+
+ Returns:
+ `np.ndarray` of shape `(num_frequency_bins, num_mel_filters)`
+ """
+ filter_diff = np.diff(filter_freqs)
+ slopes = np.expand_dims(filter_freqs, 0) - np.expand_dims(fft_freqs, 1)
+ down_slopes = -slopes[:, :-2] / filter_diff[:-1]
+ up_slopes = slopes[:, 2:] / filter_diff[1:]
+ return np.maximum(np.zeros(1), np.minimum(down_slopes, up_slopes))
+
+
+def chroma_filter_bank(
+ num_frequency_bins: int,
+ num_chroma: int,
+ sampling_rate: int,
+ tuning: float = 0.0,
+ power: Optional[float] = 2.0,
+ weighting_parameters: Optional[tuple[float, float]] = (5.0, 2.0),
+ start_at_c_chroma: bool = True,
+):
+ """
+ Creates a chroma filter bank, i.e a linear transformation to project spectrogram bins onto chroma bins.
+
+ Adapted from *librosa*.
+
+ Args:
+ num_frequency_bins (`int`):
+ Number of frequencies used to compute the spectrogram (should be the same as in `stft`).
+ num_chroma (`int`):
+ Number of chroma bins (i.e pitch classes).
+ sampling_rate (`float`):
+ Sample rate of the audio waveform.
+ tuning (`float`):
+ Tuning deviation from A440 in fractions of a chroma bin.
+ power (`float`, *optional*, defaults to 2.0):
+ If 12.0, normalizes each column with their L2 norm. If 1.0, normalizes each column with their L1 norm.
+ weighting_parameters (`tuple[float, float]`, *optional*, defaults to `(5., 2.)`):
+ If specified, apply a Gaussian weighting parameterized by the first element of the tuple being the center and
+ the second element being the Gaussian half-width.
+ start_at_c_chroma (`bool`, *optional*, defaults to `True`):
+ If True, the filter bank will start at the 'C' pitch class. Otherwise, it will start at 'A'.
+ Returns:
+ `np.ndarray` of shape `(num_frequency_bins, num_chroma)`
+ """
+ # Get the FFT bins, not counting the DC component
+ frequencies = np.linspace(0, sampling_rate, num_frequency_bins, endpoint=False)[1:]
+
+ freq_bins = num_chroma * hertz_to_octave(frequencies, tuning=tuning, bins_per_octave=num_chroma)
+
+ # make up a value for the 0 Hz bin = 1.5 octaves below bin 1
+ # (so chroma is 50% rotated from bin 1, and bin width is broad)
+ freq_bins = np.concatenate(([freq_bins[0] - 1.5 * num_chroma], freq_bins))
+
+ bins_width = np.concatenate((np.maximum(freq_bins[1:] - freq_bins[:-1], 1.0), [1]))
+
+ chroma_filters = np.subtract.outer(freq_bins, np.arange(0, num_chroma, dtype="d")).T
+
+ num_chroma2 = np.round(float(num_chroma) / 2)
+
+ # Project into range -num_chroma/2 .. num_chroma/2
+ # add on fixed offset of 10*num_chroma to ensure all values passed to
+ # rem are positive
+ chroma_filters = np.remainder(chroma_filters + num_chroma2 + 10 * num_chroma, num_chroma) - num_chroma2
+
+ # Gaussian bumps - 2*D to make them narrower
+ chroma_filters = np.exp(-0.5 * (2 * chroma_filters / np.tile(bins_width, (num_chroma, 1))) ** 2)
+
+ # normalize each column
+ if power is not None:
+ chroma_filters = chroma_filters / np.sum(chroma_filters**power, axis=0, keepdims=True) ** (1.0 / power)
+
+ # Maybe apply scaling for fft bins
+ if weighting_parameters is not None:
+ center, half_width = weighting_parameters
+ chroma_filters *= np.tile(
+ np.exp(-0.5 * (((freq_bins / num_chroma - center) / half_width) ** 2)),
+ (num_chroma, 1),
+ )
+
+ if start_at_c_chroma:
+ chroma_filters = np.roll(chroma_filters, -3 * (num_chroma // 12), axis=0)
+
+ # remove aliasing columns, copy to ensure row-contiguity
+ return np.ascontiguousarray(chroma_filters[:, : int(1 + num_frequency_bins / 2)])
+
+
+def mel_filter_bank(
+ num_frequency_bins: int,
+ num_mel_filters: int,
+ min_frequency: float,
+ max_frequency: float,
+ sampling_rate: int,
+ norm: Optional[str] = None,
+ mel_scale: str = "htk",
+ triangularize_in_mel_space: bool = False,
+) -> np.ndarray:
+ """
+ Creates a frequency bin conversion matrix used to obtain a mel spectrogram. This is called a *mel filter bank*, and
+ various implementation exist, which differ in the number of filters, the shape of the filters, the way the filters
+ are spaced, the bandwidth of the filters, and the manner in which the spectrum is warped. The goal of these
+ features is to approximate the non-linear human perception of the variation in pitch with respect to the frequency.
+
+ Different banks of mel filters were introduced in the literature. The following variations are supported:
+
+ - MFCC FB-20: introduced in 1980 by Davis and Mermelstein, it assumes a sampling frequency of 10 kHz and a speech
+ bandwidth of `[0, 4600]` Hz.
+ - MFCC FB-24 HTK: from the Cambridge HMM Toolkit (HTK) (1995) uses a filter bank of 24 filters for a speech
+ bandwidth of `[0, 8000]` Hz. This assumes sampling rate ≥ 16 kHz.
+ - MFCC FB-40: from the Auditory Toolbox for MATLAB written by Slaney in 1998, assumes a sampling rate of 16 kHz and
+ speech bandwidth of `[133, 6854]` Hz. This version also includes area normalization.
+ - HFCC-E FB-29 (Human Factor Cepstral Coefficients) of Skowronski and Harris (2004), assumes a sampling rate of
+ 12.5 kHz and speech bandwidth of `[0, 6250]` Hz.
+
+ This code is adapted from *torchaudio* and *librosa*. Note that the default parameters of torchaudio's
+ `melscale_fbanks` implement the `"htk"` filters while librosa uses the `"slaney"` implementation.
+
+ Args:
+ num_frequency_bins (`int`):
+ Number of frequency bins (should be the same as `n_fft // 2 + 1` where `n_fft` is the size of the Fourier Transform used to compute the spectrogram).
+ num_mel_filters (`int`):
+ Number of mel filters to generate.
+ min_frequency (`float`):
+ Lowest frequency of interest in Hz.
+ max_frequency (`float`):
+ Highest frequency of interest in Hz. This should not exceed `sampling_rate / 2`.
+ sampling_rate (`int`):
+ Sample rate of the audio waveform.
+ norm (`str`, *optional*):
+ If `"slaney"`, divide the triangular mel weights by the width of the mel band (area normalization).
+ mel_scale (`str`, *optional*, defaults to `"htk"`):
+ The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`.
+ triangularize_in_mel_space (`bool`, *optional*, defaults to `False`):
+ If this option is enabled, the triangular filter is applied in mel space rather than frequency space. This
+ should be set to `true` in order to get the same results as `torchaudio` when computing mel filters.
+
+ Returns:
+ `np.ndarray` of shape (`num_frequency_bins`, `num_mel_filters`): Triangular filter bank matrix. This is a
+ projection matrix to go from a spectrogram to a mel spectrogram.
+ """
+ if norm is not None and norm != "slaney":
+ raise ValueError('norm must be one of None or "slaney"')
+
+ if num_frequency_bins < 2:
+ raise ValueError(f"Require num_frequency_bins: {num_frequency_bins} >= 2")
+
+ if min_frequency > max_frequency:
+ raise ValueError(f"Require min_frequency: {min_frequency} <= max_frequency: {max_frequency}")
+
+ # center points of the triangular mel filters
+ mel_min = hertz_to_mel(min_frequency, mel_scale=mel_scale)
+ mel_max = hertz_to_mel(max_frequency, mel_scale=mel_scale)
+ mel_freqs = np.linspace(mel_min, mel_max, num_mel_filters + 2)
+ filter_freqs = mel_to_hertz(mel_freqs, mel_scale=mel_scale)
+
+ if triangularize_in_mel_space:
+ # frequencies of FFT bins in Hz, but filters triangularized in mel space
+ fft_bin_width = sampling_rate / ((num_frequency_bins - 1) * 2)
+ fft_freqs = hertz_to_mel(fft_bin_width * np.arange(num_frequency_bins), mel_scale=mel_scale)
+ filter_freqs = mel_freqs
+ else:
+ # frequencies of FFT bins in Hz
+ fft_freqs = np.linspace(0, sampling_rate // 2, num_frequency_bins)
+
+ mel_filters = _create_triangular_filter_bank(fft_freqs, filter_freqs)
+
+ if norm is not None and norm == "slaney":
+ # Slaney-style mel is scaled to be approx constant energy per channel
+ enorm = 2.0 / (filter_freqs[2 : num_mel_filters + 2] - filter_freqs[:num_mel_filters])
+ mel_filters *= np.expand_dims(enorm, 0)
+
+ if (mel_filters.max(axis=0) == 0.0).any():
+ warnings.warn(
+ "At least one mel filter has all zero values. "
+ f"The value for `num_mel_filters` ({num_mel_filters}) may be set too high. "
+ f"Or, the value for `num_frequency_bins` ({num_frequency_bins}) may be set too low."
+ )
+
+ return mel_filters
+
+
+def optimal_fft_length(window_length: int) -> int:
+ """
+ Finds the best FFT input size for a given `window_length`. This function takes a given window length and, if not
+ already a power of two, rounds it up to the next power or two.
+
+ The FFT algorithm works fastest when the length of the input is a power of two, which may be larger than the size
+ of the window or analysis frame. For example, if the window is 400 samples, using an FFT input size of 512 samples
+ is more optimal than an FFT size of 400 samples. Using a larger FFT size does not affect the detected frequencies,
+ it simply gives a higher frequency resolution (i.e. the frequency bins are smaller).
+ """
+ return 2 ** int(np.ceil(np.log2(window_length)))
+
+
+def window_function(
+ window_length: int,
+ name: str = "hann",
+ periodic: bool = True,
+ frame_length: Optional[int] = None,
+ center: bool = True,
+) -> np.ndarray:
+ """
+ Returns an array containing the specified window. This window is intended to be used with `stft`.
+
+ The following window types are supported:
+
+ - `"boxcar"`: a rectangular window
+ - `"hamming"`: the Hamming window
+ - `"hann"`: the Hann window
+ - `"povey"`: the Povey window
+
+ Args:
+ window_length (`int`):
+ The length of the window in samples.
+ name (`str`, *optional*, defaults to `"hann"`):
+ The name of the window function.
+ periodic (`bool`, *optional*, defaults to `True`):
+ Whether the window is periodic or symmetric.
+ frame_length (`int`, *optional*):
+ The length of the analysis frames in samples. Provide a value for `frame_length` if the window is smaller
+ than the frame length, so that it will be zero-padded.
+ center (`bool`, *optional*, defaults to `True`):
+ Whether to center the window inside the FFT buffer. Only used when `frame_length` is provided.
+
+ Returns:
+ `np.ndarray` of shape `(window_length,)` or `(frame_length,)` containing the window.
+ """
+ length = window_length + 1 if periodic else window_length
+
+ if name == "boxcar":
+ window = np.ones(length)
+ elif name in ["hamming", "hamming_window"]:
+ window = np.hamming(length)
+ elif name in ["hann", "hann_window"]:
+ window = np.hanning(length)
+ elif name == "povey":
+ window = np.power(np.hanning(length), 0.85)
+ else:
+ raise ValueError(f"Unknown window function '{name}'")
+
+ if periodic:
+ window = window[:-1]
+
+ if frame_length is None:
+ return window
+
+ if window_length > frame_length:
+ raise ValueError(
+ f"Length of the window ({window_length}) may not be larger than frame_length ({frame_length})"
+ )
+
+ padded_window = np.zeros(frame_length)
+ offset = (frame_length - window_length) // 2 if center else 0
+ padded_window[offset : offset + window_length] = window
+ return padded_window
+
+
+# TODO This method does not support batching yet as we are mainly focused on inference.
+def spectrogram(
+ waveform: np.ndarray,
+ window: np.ndarray,
+ frame_length: int,
+ hop_length: int,
+ fft_length: Optional[int] = None,
+ power: Optional[float] = 1.0,
+ center: bool = True,
+ pad_mode: str = "reflect",
+ onesided: bool = True,
+ dither: float = 0.0,
+ preemphasis: Optional[float] = None,
+ mel_filters: Optional[np.ndarray] = None,
+ mel_floor: float = 1e-10,
+ log_mel: Optional[str] = None,
+ reference: float = 1.0,
+ min_value: float = 1e-10,
+ db_range: Optional[float] = None,
+ remove_dc_offset: bool = False,
+ dtype: np.dtype = np.float32,
+) -> np.ndarray:
+ """
+ Calculates a spectrogram over one waveform using the Short-Time Fourier Transform.
+
+ This function can create the following kinds of spectrograms:
+
+ - amplitude spectrogram (`power = 1.0`)
+ - power spectrogram (`power = 2.0`)
+ - complex-valued spectrogram (`power = None`)
+ - log spectrogram (use `log_mel` argument)
+ - mel spectrogram (provide `mel_filters`)
+ - log-mel spectrogram (provide `mel_filters` and `log_mel`)
+
+ How this works:
+
+ 1. The input waveform is split into frames of size `frame_length` that are partially overlapping by `frame_length
+ - hop_length` samples.
+ 2. Each frame is multiplied by the window and placed into a buffer of size `fft_length`.
+ 3. The DFT is taken of each windowed frame.
+ 4. The results are stacked into a spectrogram.
+
+ We make a distinction between the following "blocks" of sample data, each of which may have a different lengths:
+
+ - The analysis frame. This is the size of the time slices that the input waveform is split into.
+ - The window. Each analysis frame is multiplied by the window to avoid spectral leakage.
+ - The FFT input buffer. The length of this determines how many frequency bins are in the spectrogram.
+
+ In this implementation, the window is assumed to be zero-padded to have the same size as the analysis frame. A
+ padded window can be obtained from `window_function()`. The FFT input buffer may be larger than the analysis frame,
+ typically the next power of two.
+
+ Note: This function is not optimized for speed yet. It should be mostly compatible with `librosa.stft` and
+ `torchaudio.functional.transforms.Spectrogram`, although it is more flexible due to the different ways spectrograms
+ can be constructed.
+
+ Args:
+ waveform (`np.ndarray` of shape `(length,)`):
+ The input waveform. This must be a single real-valued, mono waveform.
+ window (`np.ndarray` of shape `(frame_length,)`):
+ The windowing function to apply, including zero-padding if necessary. The actual window length may be
+ shorter than `frame_length`, but we're assuming the array has already been zero-padded.
+ frame_length (`int`):
+ The length of the analysis frames in samples. With librosa this is always equal to `fft_length` but we also
+ allow smaller sizes.
+ hop_length (`int`):
+ The stride between successive analysis frames in samples.
+ fft_length (`int`, *optional*):
+ The size of the FFT buffer in samples. This determines how many frequency bins the spectrogram will have.
+ For optimal speed, this should be a power of two. If `None`, uses `frame_length`.
+ power (`float`, *optional*, defaults to 1.0):
+ If 1.0, returns the amplitude spectrogram. If 2.0, returns the power spectrogram. If `None`, returns
+ complex numbers.
+ center (`bool`, *optional*, defaults to `True`):
+ Whether to pad the waveform so that frame `t` is centered around time `t * hop_length`. If `False`, frame
+ `t` will start at time `t * hop_length`.
+ pad_mode (`str`, *optional*, defaults to `"reflect"`):
+ Padding mode used when `center` is `True`. Possible values are: `"constant"` (pad with zeros), `"edge"`
+ (pad with edge values), `"reflect"` (pads with mirrored values).
+ onesided (`bool`, *optional*, defaults to `True`):
+ If True, only computes the positive frequencies and returns a spectrogram containing `fft_length // 2 + 1`
+ frequency bins. If False, also computes the negative frequencies and returns `fft_length` frequency bins.
+ dither (`float`, *optional*, defaults to 0.0):
+ Adds dithering. In other words, adds a small Gaussian noise to each frame.
+ E.g. use 4.0 to add dithering with a normal distribution centered
+ around 0.0 with standard deviation 4.0, 0.0 means no dithering.
+ Dithering has similar effect as `mel_floor`. It reduces the high log_mel_fbank
+ values for signals with hard-zero sections, when VAD cutoff is present in the signal.
+ preemphasis (`float`, *optional*)
+ Coefficient for a low-pass filter that applies pre-emphasis before the DFT.
+ mel_filters (`np.ndarray` of shape `(num_freq_bins, num_mel_filters)`, *optional*):
+ The mel filter bank. If supplied, applies a this filter bank to create a mel spectrogram.
+ mel_floor (`float`, *optional*, defaults to 1e-10):
+ Minimum value of mel frequency banks.
+ log_mel (`str`, *optional*):
+ How to convert the spectrogram to log scale. Possible options are: `None` (don't convert), `"log"` (take
+ the natural logarithm) `"log10"` (take the base-10 logarithm), `"dB"` (convert to decibels). Can only be
+ used when `power` is not `None`.
+ reference (`float`, *optional*, defaults to 1.0):
+ Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set
+ the loudest part to 0 dB. Must be greater than zero.
+ min_value (`float`, *optional*, defaults to `1e-10`):
+ The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking
+ `log(0)`. For a power spectrogram, the default of `1e-10` corresponds to a minimum of -100 dB. For an
+ amplitude spectrogram, the value `1e-5` corresponds to -100 dB. Must be greater than zero.
+ db_range (`float`, *optional*):
+ Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
+ peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
+ remove_dc_offset (`bool`, *optional*):
+ Subtract mean from waveform on each frame, applied before pre-emphasis. This should be set to `true` in
+ order to get the same results as `torchaudio.compliance.kaldi.fbank` when computing mel filters.
+ dtype (`np.dtype`, *optional*, defaults to `np.float32`):
+ Data type of the spectrogram tensor. If `power` is None, this argument is ignored and the dtype will be
+ `np.complex64`.
+
+ Returns:
+ `nd.array` containing a spectrogram of shape `(num_frequency_bins, length)` for a regular spectrogram or shape
+ `(num_mel_filters, length)` for a mel spectrogram.
+ """
+ window_length = len(window)
+
+ if fft_length is None:
+ fft_length = frame_length
+
+ if frame_length > fft_length:
+ raise ValueError(f"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})")
+
+ if window_length != frame_length:
+ raise ValueError(f"Length of the window ({window_length}) must equal frame_length ({frame_length})")
+
+ if hop_length <= 0:
+ raise ValueError("hop_length must be greater than zero")
+
+ if waveform.ndim != 1:
+ raise ValueError(f"Input waveform must have only one dimension, shape is {waveform.shape}")
+
+ if np.iscomplexobj(waveform):
+ raise ValueError("Complex-valued input waveforms are not currently supported")
+
+ if power is None and mel_filters is not None:
+ raise ValueError(
+ "You have provided `mel_filters` but `power` is `None`. Mel spectrogram computation is not yet supported for complex-valued spectrogram."
+ "Specify `power` to fix this issue."
+ )
+
+ # center pad the waveform
+ if center:
+ padding = [(int(frame_length // 2), int(frame_length // 2))]
+ waveform = np.pad(waveform, padding, mode=pad_mode)
+
+ # promote to float64, since np.fft uses float64 internally
+ waveform = waveform.astype(np.float64)
+ window = window.astype(np.float64)
+
+ # split waveform into frames of frame_length size
+ num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))
+
+ num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length
+ spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)
+
+ # rfft is faster than fft
+ fft_func = np.fft.rfft if onesided else np.fft.fft
+ buffer = np.zeros(fft_length)
+
+ timestep = 0
+ for frame_idx in range(num_frames):
+ buffer[:frame_length] = waveform[timestep : timestep + frame_length]
+
+ if dither != 0.0:
+ buffer[:frame_length] += dither * np.random.randn(frame_length)
+
+ if remove_dc_offset:
+ buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()
+
+ if preemphasis is not None:
+ buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]
+ buffer[0] *= 1 - preemphasis
+
+ buffer[:frame_length] *= window
+
+ spectrogram[frame_idx] = fft_func(buffer)
+ timestep += hop_length
+
+ # note: ** is much faster than np.power
+ if power is not None:
+ spectrogram = np.abs(spectrogram, dtype=np.float64) ** power
+
+ spectrogram = spectrogram.T
+
+ if mel_filters is not None:
+ spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))
+
+ if power is not None and log_mel is not None:
+ if log_mel == "log":
+ spectrogram = np.log(spectrogram)
+ elif log_mel == "log10":
+ spectrogram = np.log10(spectrogram)
+ elif log_mel == "dB":
+ if power == 1.0:
+ spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)
+ elif power == 2.0:
+ spectrogram = power_to_db(spectrogram, reference, min_value, db_range)
+ else:
+ raise ValueError(f"Cannot use log_mel option '{log_mel}' with power {power}")
+ else:
+ raise ValueError(f"Unknown log_mel option: {log_mel}")
+
+ spectrogram = np.asarray(spectrogram, dtype)
+
+ return spectrogram
+
+
+def spectrogram_batch(
+ waveform_list: list[np.ndarray],
+ window: np.ndarray,
+ frame_length: int,
+ hop_length: int,
+ fft_length: Optional[int] = None,
+ power: Optional[float] = 1.0,
+ center: bool = True,
+ pad_mode: str = "reflect",
+ onesided: bool = True,
+ dither: float = 0.0,
+ preemphasis: Optional[float] = None,
+ mel_filters: Optional[np.ndarray] = None,
+ mel_floor: float = 1e-10,
+ log_mel: Optional[str] = None,
+ reference: float = 1.0,
+ min_value: float = 1e-10,
+ db_range: Optional[float] = None,
+ remove_dc_offset: bool = False,
+ dtype: np.dtype = np.float32,
+) -> list[np.ndarray]:
+ """
+ Calculates spectrograms for a list of waveforms using the Short-Time Fourier Transform, optimized for batch processing.
+ This function extends the capabilities of the `spectrogram` function to handle multiple waveforms efficiently by leveraging broadcasting.
+
+ It supports generating various types of spectrograms:
+
+ - amplitude spectrogram (`power = 1.0`)
+ - power spectrogram (`power = 2.0`)
+ - complex-valued spectrogram (`power = None`)
+ - log spectrogram (use `log_mel` argument)
+ - mel spectrogram (provide `mel_filters`)
+ - log-mel spectrogram (provide `mel_filters` and `log_mel`)
+
+ How this works:
+
+ 1. The input waveform is split into frames of size `frame_length` that are partially overlapping by `frame_length
+ - hop_length` samples.
+ 2. Each frame is multiplied by the window and placed into a buffer of size `fft_length`.
+ 3. The DFT is taken of each windowed frame.
+ 4. The results are stacked into a spectrogram.
+
+ We make a distinction between the following "blocks" of sample data, each of which may have a different lengths:
+
+ - The analysis frame. This is the size of the time slices that the input waveform is split into.
+ - The window. Each analysis frame is multiplied by the window to avoid spectral leakage.
+ - The FFT input buffer. The length of this determines how many frequency bins are in the spectrogram.
+
+ In this implementation, the window is assumed to be zero-padded to have the same size as the analysis frame. A
+ padded window can be obtained from `window_function()`. The FFT input buffer may be larger than the analysis frame,
+ typically the next power of two.
+
+ Note: This function is designed for efficient batch processing of multiple waveforms but retains compatibility with individual waveform processing methods like `librosa.stft`.
+
+ Args:
+ waveform_list (`list[np.ndarray]` with arrays of shape `(length,)`):
+ The list of input waveforms, each a single-channel (mono) signal.
+ window (`np.ndarray` of shape `(frame_length,)`):
+ The windowing function to apply, including zero-padding if necessary.
+ frame_length (`int`):
+ The length of each frame for analysis.
+ hop_length (`int`):
+ The step size between successive frames.
+ fft_length (`int`, *optional*):
+ The size of the FFT buffer, defining frequency bin resolution.
+ power (`float`, *optional*, defaults to 1.0):
+ Determines the type of spectrogram: 1.0 for amplitude, 2.0 for power, None for complex.
+ center (`bool`, *optional*, defaults to `True`):
+ Whether to center-pad the waveform frames.
+ pad_mode (`str`, *optional*, defaults to `"reflect"`):
+ The padding strategy when `center` is `True`.
+ onesided (`bool`, *optional*, defaults to `True`):
+ If True, returns a one-sided spectrogram for real input signals.
+ dither (`float`, *optional*, defaults to 0.0):
+ Adds dithering. In other words, adds a small Gaussian noise to each frame.
+ E.g. use 4.0 to add dithering with a normal distribution centered
+ around 0.0 with standard deviation 4.0, 0.0 means no dithering.
+ preemphasis (`float`, *optional*):
+ Applies a pre-emphasis filter to each frame.
+ mel_filters (`np.ndarray`, *optional*):
+ Mel filter bank for converting to mel spectrogram.
+ mel_floor (`float`, *optional*, defaults to 1e-10):
+ Floor value for mel spectrogram to avoid log(0).
+ log_mel (`str`, *optional*):
+ Specifies log scaling strategy; options are None, "log", "log10", "dB".
+ reference (`float`, *optional*, defaults to 1.0):
+ Reference value for dB conversion in log_mel.
+ min_value (`float`, *optional*, defaults to 1e-10):
+ Minimum floor value for log scale conversions.
+ db_range (`float`, *optional*):
+ Dynamic range for dB scale spectrograms.
+ remove_dc_offset (`bool`, *optional*):
+ Whether to remove the DC offset from each frame.
+ dtype (`np.dtype`, *optional*, defaults to `np.float32`):
+ Data type of the output spectrogram.
+
+ Returns:
+ list[`np.ndarray`]: A list of spectrogram arrays, one for each input waveform.
+ """
+ window_length = len(window)
+
+ if fft_length is None:
+ fft_length = frame_length
+
+ if frame_length > fft_length:
+ raise ValueError(f"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})")
+
+ if window_length != frame_length:
+ raise ValueError(f"Length of the window ({window_length}) must equal frame_length ({frame_length})")
+
+ if hop_length <= 0:
+ raise ValueError("hop_length must be greater than zero")
+
+ # Check the dimensions of the waveform , and if waveform is complex
+ for waveform in waveform_list:
+ if waveform.ndim != 1:
+ raise ValueError(f"Input waveform must have only one dimension, shape is {waveform.shape}")
+ if np.iscomplexobj(waveform):
+ raise ValueError("Complex-valued input waveforms are not currently supported")
+ # Center pad the waveform
+ if center:
+ padding = [(int(frame_length // 2), int(frame_length // 2))]
+ waveform_list = [
+ np.pad(
+ waveform,
+ padding,
+ mode=pad_mode,
+ )
+ for waveform in waveform_list
+ ]
+ original_waveform_lengths = [
+ len(waveform) for waveform in waveform_list
+ ] # these lengths will be used to remove padding later
+
+ # Batch pad the waveform
+ max_length = max(original_waveform_lengths)
+ padded_waveform_batch = np.array(
+ [
+ np.pad(waveform, (0, max_length - len(waveform)), mode="constant", constant_values=0)
+ for waveform in waveform_list
+ ],
+ dtype=dtype,
+ )
+
+ # Promote to float64, since np.fft uses float64 internally
+ padded_waveform_batch = padded_waveform_batch.astype(np.float64)
+ window = window.astype(np.float64)
+
+ # Split waveform into frames of frame_length size
+ num_frames = int(1 + np.floor((padded_waveform_batch.shape[1] - frame_length) / hop_length))
+ # these lengths will be used to remove padding later
+ true_num_frames = [int(1 + np.floor((length - frame_length) / hop_length)) for length in original_waveform_lengths]
+ num_batches = padded_waveform_batch.shape[0]
+
+ num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length
+ spectrogram = np.empty((num_batches, num_frames, num_frequency_bins), dtype=np.complex64)
+
+ # rfft is faster than fft
+ fft_func = np.fft.rfft if onesided else np.fft.fft
+ buffer = np.zeros((num_batches, fft_length))
+
+ for frame_idx in range(num_frames):
+ timestep = frame_idx * hop_length
+ buffer[:, :frame_length] = padded_waveform_batch[:, timestep : timestep + frame_length]
+
+ if dither != 0.0:
+ buffer[:, :frame_length] += dither * np.random.randn(*buffer[:, :frame_length].shape)
+
+ if remove_dc_offset:
+ buffer[:, :frame_length] -= buffer[:, :frame_length].mean(axis=1, keepdims=True)
+
+ if preemphasis is not None:
+ buffer[:, 1:frame_length] -= preemphasis * buffer[:, : frame_length - 1]
+ buffer[:, 0] *= 1 - preemphasis
+
+ buffer[:, :frame_length] *= window
+
+ spectrogram[:, frame_idx] = fft_func(buffer)
+
+ # Note: ** is much faster than np.power
+ if power is not None:
+ spectrogram = np.abs(spectrogram, dtype=np.float64) ** power
+
+ # Apply mel filters if provided
+ if mel_filters is not None:
+ result = np.tensordot(spectrogram, mel_filters.T, axes=([2], [1]))
+ spectrogram = np.maximum(mel_floor, result)
+
+ # Convert to log scale if specified
+ if power is not None and log_mel is not None:
+ if log_mel == "log":
+ spectrogram = np.log(spectrogram)
+ elif log_mel == "log10":
+ spectrogram = np.log10(spectrogram)
+ elif log_mel == "dB":
+ if power == 1.0:
+ spectrogram = amplitude_to_db_batch(spectrogram, reference, min_value, db_range)
+ elif power == 2.0:
+ spectrogram = power_to_db_batch(spectrogram, reference, min_value, db_range)
+ else:
+ raise ValueError(f"Cannot use log_mel option '{log_mel}' with power {power}")
+ else:
+ raise ValueError(f"Unknown log_mel option: {log_mel}")
+
+ spectrogram = np.asarray(spectrogram, dtype)
+
+ spectrogram_list = [spectrogram[i, : true_num_frames[i], :].T for i in range(len(true_num_frames))]
+
+ return spectrogram_list
+
+
+def power_to_db(
+ spectrogram: np.ndarray,
+ reference: float = 1.0,
+ min_value: float = 1e-10,
+ db_range: Optional[float] = None,
+) -> np.ndarray:
+ """
+ Converts a power spectrogram to the decibel scale. This computes `10 * log10(spectrogram / reference)`, using basic
+ logarithm properties for numerical stability.
+
+ The motivation behind applying the log function on the (mel) spectrogram is that humans do not hear loudness on a
+ linear scale. Generally to double the perceived volume of a sound we need to put 8 times as much energy into it.
+ This means that large variations in energy may not sound all that different if the sound is loud to begin with.
+ This compression operation makes the (mel) spectrogram features match more closely what humans actually hear.
+
+ Based on the implementation of `librosa.power_to_db`.
+
+ Args:
+ spectrogram (`np.ndarray`):
+ The input power (mel) spectrogram. Note that a power spectrogram has the amplitudes squared!
+ reference (`float`, *optional*, defaults to 1.0):
+ Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set
+ the loudest part to 0 dB. Must be greater than zero.
+ min_value (`float`, *optional*, defaults to `1e-10`):
+ The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking
+ `log(0)`. The default of `1e-10` corresponds to a minimum of -100 dB. Must be greater than zero.
+ db_range (`float`, *optional*):
+ Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
+ peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
+
+ Returns:
+ `np.ndarray`: the spectrogram in decibels
+ """
+ if reference <= 0.0:
+ raise ValueError("reference must be greater than zero")
+ if min_value <= 0.0:
+ raise ValueError("min_value must be greater than zero")
+
+ reference = max(min_value, reference)
+
+ spectrogram = np.clip(spectrogram, a_min=min_value, a_max=None)
+ spectrogram = 10.0 * (np.log10(spectrogram) - np.log10(reference))
+
+ if db_range is not None:
+ if db_range <= 0.0:
+ raise ValueError("db_range must be greater than zero")
+ spectrogram = np.clip(spectrogram, a_min=spectrogram.max() - db_range, a_max=None)
+
+ return spectrogram
+
+
+def power_to_db_batch(
+ spectrogram: np.ndarray,
+ reference: float = 1.0,
+ min_value: float = 1e-10,
+ db_range: Optional[float] = None,
+) -> np.ndarray:
+ """
+ Converts a batch of power spectrograms to the decibel scale. This computes `10 * log10(spectrogram / reference)`,
+ using basic logarithm properties for numerical stability.
+
+ This function supports batch processing, where each item in the batch is an individual power (mel) spectrogram.
+
+ Args:
+ spectrogram (`np.ndarray`):
+ The input batch of power (mel) spectrograms. Expected shape is (batch_size, *spectrogram_shape).
+ Note that a power spectrogram has the amplitudes squared!
+ reference (`float`, *optional*, defaults to 1.0):
+ Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set
+ the loudest part to 0 dB. Must be greater than zero.
+ min_value (`float`, *optional*, defaults to `1e-10`):
+ The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking
+ `log(0)`. The default of `1e-10` corresponds to a minimum of -100 dB. Must be greater than zero.
+ db_range (`float`, *optional*):
+ Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
+ peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
+
+ Returns:
+ `np.ndarray`: the batch of spectrograms in decibels
+ """
+ if reference <= 0.0:
+ raise ValueError("reference must be greater than zero")
+ if min_value <= 0.0:
+ raise ValueError("min_value must be greater than zero")
+
+ reference = max(min_value, reference)
+
+ spectrogram = np.clip(spectrogram, a_min=min_value, a_max=None)
+ spectrogram = 10.0 * (np.log10(spectrogram) - np.log10(reference))
+
+ if db_range is not None:
+ if db_range <= 0.0:
+ raise ValueError("db_range must be greater than zero")
+ # Apply db_range clipping per batch item
+ max_values = spectrogram.max(axis=(1, 2), keepdims=True)
+ spectrogram = np.clip(spectrogram, a_min=max_values - db_range, a_max=None)
+
+ return spectrogram
+
+
+def amplitude_to_db(
+ spectrogram: np.ndarray,
+ reference: float = 1.0,
+ min_value: float = 1e-5,
+ db_range: Optional[float] = None,
+) -> np.ndarray:
+ """
+ Converts an amplitude spectrogram to the decibel scale. This computes `20 * log10(spectrogram / reference)`, using
+ basic logarithm properties for numerical stability.
+
+ The motivation behind applying the log function on the (mel) spectrogram is that humans do not hear loudness on a
+ linear scale. Generally to double the perceived volume of a sound we need to put 8 times as much energy into it.
+ This means that large variations in energy may not sound all that different if the sound is loud to begin with.
+ This compression operation makes the (mel) spectrogram features match more closely what humans actually hear.
+
+ Args:
+ spectrogram (`np.ndarray`):
+ The input amplitude (mel) spectrogram.
+ reference (`float`, *optional*, defaults to 1.0):
+ Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set
+ the loudest part to 0 dB. Must be greater than zero.
+ min_value (`float`, *optional*, defaults to `1e-5`):
+ The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking
+ `log(0)`. The default of `1e-5` corresponds to a minimum of -100 dB. Must be greater than zero.
+ db_range (`float`, *optional*):
+ Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
+ peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
+
+ Returns:
+ `np.ndarray`: the spectrogram in decibels
+ """
+ if reference <= 0.0:
+ raise ValueError("reference must be greater than zero")
+ if min_value <= 0.0:
+ raise ValueError("min_value must be greater than zero")
+
+ reference = max(min_value, reference)
+
+ spectrogram = np.clip(spectrogram, a_min=min_value, a_max=None)
+ spectrogram = 20.0 * (np.log10(spectrogram) - np.log10(reference))
+
+ if db_range is not None:
+ if db_range <= 0.0:
+ raise ValueError("db_range must be greater than zero")
+ spectrogram = np.clip(spectrogram, a_min=spectrogram.max() - db_range, a_max=None)
+
+ return spectrogram
+
+
+def amplitude_to_db_batch(
+ spectrogram: np.ndarray, reference: float = 1.0, min_value: float = 1e-5, db_range: Optional[float] = None
+) -> np.ndarray:
+ """
+ Converts a batch of amplitude spectrograms to the decibel scale. This computes `20 * log10(spectrogram / reference)`,
+ using basic logarithm properties for numerical stability.
+
+ The function supports batch processing, where each item in the batch is an individual amplitude (mel) spectrogram.
+
+ Args:
+ spectrogram (`np.ndarray`):
+ The input batch of amplitude (mel) spectrograms. Expected shape is (batch_size, *spectrogram_shape).
+ reference (`float`, *optional*, defaults to 1.0):
+ Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set
+ the loudest part to 0 dB. Must be greater than zero.
+ min_value (`float`, *optional*, defaults to `1e-5`):
+ The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking
+ `log(0)`. The default of `1e-5` corresponds to a minimum of -100 dB. Must be greater than zero.
+ db_range (`float`, *optional*):
+ Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
+ peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
+
+ Returns:
+ `np.ndarray`: the batch of spectrograms in decibels
+ """
+ if reference <= 0.0:
+ raise ValueError("reference must be greater than zero")
+ if min_value <= 0.0:
+ raise ValueError("min_value must be greater than zero")
+
+ reference = max(min_value, reference)
+
+ spectrogram = np.clip(spectrogram, a_min=min_value, a_max=None)
+ spectrogram = 20.0 * (np.log10(spectrogram) - np.log10(reference))
+
+ if db_range is not None:
+ if db_range <= 0.0:
+ raise ValueError("db_range must be greater than zero")
+ # Apply db_range clipping per batch item
+ max_values = spectrogram.max(axis=(1, 2), keepdims=True)
+ spectrogram = np.clip(spectrogram, a_min=max_values - db_range, a_max=None)
+
+ return spectrogram
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/cache_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/cache_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..99beb0b610a1f88dbd2dadf473f829c0be12d8cc
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/cache_utils.py
@@ -0,0 +1,1493 @@
+from abc import ABC, abstractmethod
+from collections.abc import Iterable
+from typing import Any, Optional
+
+import torch
+
+from .configuration_utils import PretrainedConfig
+from .utils import (
+ is_hqq_available,
+ is_quanto_greater,
+ is_torch_greater_or_equal,
+ is_torchdynamo_compiling,
+ logging,
+)
+
+
+if is_hqq_available():
+ from hqq.core.quantize import Quantizer as HQQQuantizer
+
+_is_torch_greater_or_equal_than_2_7 = is_torch_greater_or_equal("2.7", accept_dev=True)
+
+
+logger = logging.get_logger(__name__)
+
+
+class CacheLayerMixin(ABC):
+ """Base, abstract class for a single layer's cache."""
+
+ is_compileable = False
+
+ def __init__(self):
+ self.keys: Optional[torch.Tensor] = None
+ self.values: Optional[torch.Tensor] = None
+ self.is_initialized = False
+
+ def __repr__(self):
+ return f"{self.__class__.__name__}"
+
+ @abstractmethod
+ def lazy_initialization(self, key_states: torch.Tensor): ...
+
+ @abstractmethod
+ def update(
+ self, key_states: torch.Tensor, value_states: torch.Tensor, cache_kwargs: Optional[dict[str, Any]] = None
+ ) -> tuple[torch.Tensor, torch.Tensor]: ...
+
+ @abstractmethod
+ def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: ...
+
+ @abstractmethod
+ def get_seq_length(self) -> int: ...
+
+ @abstractmethod
+ def get_max_cache_shape(self) -> int: ...
+
+ def offload(self):
+ """Offload this layer's data to CPU device."""
+ if self.is_initialized:
+ self.keys = self.keys.to("cpu", non_blocking=True)
+ self.values = self.values.to("cpu", non_blocking=True)
+
+ def prefetch(self):
+ """In case of layer offloading, this allows to move the data back to the layer's device ahead of time."""
+ if self.is_initialized and self.keys.device != self.device:
+ self.keys = self.keys.to(self.device, non_blocking=True)
+ self.values = self.values.to(self.device, non_blocking=True)
+
+ def reset(self) -> None:
+ """Resets the cache values while preserving the objects"""
+ if self.is_initialized:
+ self.keys.zero_()
+ self.values.zero_()
+ # This attribute is set on several Layers
+ if hasattr(self, "cumulative_length"):
+ self.cumulative_length = 0
+
+ def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
+ """Reorders this layer's cache for beam search."""
+ if self.get_seq_length() > 0:
+ self.keys = self.keys.index_select(0, beam_idx.to(self.keys.device))
+ self.values = self.values.index_select(0, beam_idx.to(self.values.device))
+
+
+class DynamicLayer(CacheLayerMixin):
+ """
+ A cache layer that grows dynamically as more tokens are generated. This is the default for generative models.
+ It stores the key and value states as tensors of shape `[batch_size, num_heads, seq_len, head_dim]`.
+ """
+
+ is_sliding = False
+
+ def lazy_initialization(self, key_states: torch.Tensor):
+ self.dtype, self.device = key_states.dtype, key_states.device
+ self.keys = torch.tensor([], dtype=self.dtype, device=self.device)
+ self.values = torch.tensor([], dtype=self.dtype, device=self.device)
+ self.is_initialized = True
+
+ def update(
+ self,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ cache_kwargs: Optional[dict[str, Any]] = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Update the key and value caches in-place, and return the necessary keys and value states.
+
+ Args:
+ key_states (`torch.Tensor`): The new key states to cache.
+ value_states (`torch.Tensor`): The new value states to cache.
+ cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache.
+
+ Returns:
+ tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states.
+ """
+ # Lazy initialization
+ if not self.is_initialized:
+ self.lazy_initialization(key_states)
+
+ self.keys = torch.cat([self.keys, key_states], dim=-2)
+ self.values = torch.cat([self.values, value_states], dim=-2)
+ return self.keys, self.values
+
+ def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
+ """Return the length and offset of the cache, used to generate the mask"""
+ kv_offset = 0
+ query_length = cache_position.shape[0]
+ kv_length = self.get_seq_length() + query_length
+ return kv_length, kv_offset
+
+ def get_seq_length(self) -> int:
+ """Returns the sequence length of the cached states."""
+ if not self.is_initialized or self.keys.numel() == 0:
+ return 0
+ return self.keys.shape[-2]
+
+ def get_max_cache_shape(self) -> int:
+ """Returns the maximum sequence length of the cache object. DynamicLayer does not have a maximum length."""
+ return -1
+
+ def crop(self, max_length: int) -> None:
+ """
+ Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be negative
+ to remove `max_length` tokens.
+ """
+ if max_length < 0:
+ max_length = self.get_seq_length() - abs(max_length)
+
+ if self.get_seq_length() <= max_length:
+ return
+
+ self.keys = self.keys[..., :max_length, :]
+ self.values = self.values[..., :max_length, :]
+
+ def batch_repeat_interleave(self, repeats: int) -> None:
+ """Repeat the cache `repeats` times in the batch dimension."""
+ if self.get_seq_length() > 0:
+ self.keys = self.keys.repeat_interleave(repeats, dim=0)
+ self.values = self.values.repeat_interleave(repeats, dim=0)
+
+ def batch_select_indices(self, indices: torch.Tensor) -> None:
+ """Only keep the `indices` in the batch dimension of the cache."""
+ if self.get_seq_length() > 0:
+ self.keys = self.keys[indices, ...]
+ self.values = self.values[indices, ...]
+
+
+class DynamicSlidingWindowLayer(DynamicLayer):
+ """
+ A cache layer that grows dynamically as more tokens are generated, up until the sliding window size.
+ It stores the key and value states as tensors of shape `[batch_size, num_heads, min(seq_len, sliding_window), head_dim]`.
+ """
+
+ is_sliding = True
+
+ def __init__(self, sliding_window: int):
+ super().__init__()
+ self.sliding_window = sliding_window
+ self.cumulative_length = 0
+
+ def update(
+ self,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ cache_kwargs: Optional[dict[str, Any]] = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Update the key and value caches in-place, and return the necessary keys and value states.
+
+ Args:
+ key_states (`torch.Tensor`): The new key states to cache.
+ value_states (`torch.Tensor`): The new value states to cache.
+ cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache.
+
+ Returns:
+ tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states.
+ """
+ # Lazy initialization
+ if not self.is_initialized:
+ self.lazy_initialization(key_states)
+
+ self.cumulative_length += key_states.shape[-2]
+
+ # Compute the full states
+ full_key_states = torch.cat([self.keys, key_states], dim=-2)
+ full_value_states = torch.cat([self.values, value_states], dim=-2)
+ # Only cache the last `self.sliding_window - 1` tokens (or all of them if lower than that)
+ self.keys = full_key_states[:, :, -self.sliding_window + 1 :, :]
+ self.values = full_value_states[:, :, -self.sliding_window + 1 :, :]
+
+ # Return the full states
+ return full_key_states, full_value_states
+
+ def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
+ """Return the length and offset of the cache, used to generate the attention mask"""
+ query_length = cache_position.shape[0]
+ is_full = self.cumulative_length >= self.sliding_window
+
+ kv_offset = max(self.cumulative_length - self.sliding_window + 1, 0)
+ if is_full:
+ kv_length = self.sliding_window - 1 + query_length
+ else:
+ kv_length = self.cumulative_length + query_length
+
+ return kv_length, kv_offset
+
+ def get_seq_length(self) -> int:
+ """Returns the sequence length of the cached states."""
+ return self.cumulative_length
+
+ def get_max_cache_shape(self) -> int:
+ """Return the maximum cache shape of the cache"""
+ return self.sliding_window
+
+ def crop(self, max_length: int) -> None:
+ """
+ Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be
+ negative to remove `max_length` tokens.
+ """
+ if self.get_seq_length() >= self.sliding_window:
+ raise ValueError(
+ "Cannot `crop` a `DynamicSlidingWindowLayer` after it has seen more tokens than its"
+ "sliding window (otherwise some states are lost)"
+ )
+ super().crop(max_length)
+ self.cumulative_length = self.keys.shape[-2]
+
+
+class StaticLayer(CacheLayerMixin):
+ """
+ A static cache layer that stores the key and value states as static tensors of shape `[batch_size, num_heads, max_cache_len), head_dim]`.
+ It lazily allocates its full backing tensors, and then mutates them in-place. Built for `torch.compile` support.
+
+ Args:
+ max_cache_len (`int`):
+ Maximum number of tokens that can be stored, used for tensor preallocation.
+ """
+
+ is_compileable = True
+ is_sliding = False
+
+ def __init__(self, max_cache_len: int):
+ super().__init__()
+ self.max_cache_len = max_cache_len
+
+ def lazy_initialization(self, key_states: torch.Tensor):
+ """
+ Lazy initialization of the keys and values tensors. This allows to get all properties (dtype, device,
+ num_heads in case of TP etc...) at runtime directly, which is extremely practical as it avoids moving
+ devices, dtypes etc later on for each `update` (which could break the static dynamo addresses as well).
+
+ If this is unwanted, one can call `early_initialization(...)` on the Cache directly, which will call this
+ function ahead-of-time (this is required for `torch.export` for example). Note that for `compile`, as we
+ internally don't compile the prefill, this is guaranteed to have been called already when compiling.
+ If compiling the prefill as well, e.g. calling `model.compile(...)` before `generate` with a static cache,
+ it is still supported in general, but without guarantees depending on the compilation options (e.g. cuda graphs,
+ i.e. `mode="reduce-overhead"` is known to fail). But it will in general work correctly, and prefill should
+ not be compiled anyway for performances!
+ """
+ self.max_batch_size, self.num_heads, _, self.head_dim = key_states.shape
+ self.dtype, self.device = key_states.dtype, key_states.device
+
+ self.keys = torch.zeros(
+ (self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
+ dtype=self.dtype,
+ device=self.device,
+ )
+ self.values = torch.zeros(
+ (self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
+ dtype=self.dtype,
+ device=self.device,
+ )
+ # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing compiled graph
+ # breaks when updating the cache. However, it is not supported when tracing the graph, so we skip it in this case.
+ # As prefill should never be compiled, this is not an issue and it will still be run (except when users compile
+ # prefill explicitly, but this should be avoided!)
+ if not is_torchdynamo_compiling():
+ torch._dynamo.mark_static_address(self.keys)
+ torch._dynamo.mark_static_address(self.values)
+
+ self.is_initialized = True
+
+ def update(
+ self,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ cache_kwargs: Optional[dict[str, Any]] = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Update the key and value caches in-place, and return the necessary keys and value states.
+
+ Args:
+ key_states (`torch.Tensor`): The new key states to cache.
+ value_states (`torch.Tensor`): The new value states to cache.
+ cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache.
+
+ Returns:
+ tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states.
+ """
+ # Lazy initialization
+ if not self.is_initialized:
+ self.lazy_initialization(key_states)
+
+ # Some old models give None for `cache_position` or even omit passing `cache_kwargs` when used as cross-attention,
+ # in which case we should copy the whole Layer (key_states.shape[-2] == self.max_cache_len)
+ cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None
+ cache_position = (
+ cache_position if cache_position is not None else torch.arange(key_states.shape[-2], device=self.device)
+ )
+
+ # Update the cache
+ try:
+ self.keys.index_copy_(2, cache_position, key_states)
+ self.values.index_copy_(2, cache_position, value_states)
+ except NotImplementedError:
+ # Fallback for devices like MPS where index_copy_ might not be supported.
+ self.keys[:, :, cache_position] = key_states
+ self.values[:, :, cache_position] = value_states
+ return self.keys, self.values
+
+ def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
+ """Return the length and offset of the cache, used to generate the attention mask"""
+ kv_offset = 0
+ kv_length = self.max_cache_len
+ return kv_length, kv_offset
+
+ def get_seq_length(self) -> int:
+ """Returns the sequence length of the cached states."""
+ # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
+ # limit the check to the first batch member and head dimension.
+ return (self.keys[0, 0].any(dim=-1)).sum() if self.is_initialized else 0
+
+ def get_max_cache_shape(self) -> int:
+ """Return the maximum cache shape of the cache"""
+ return self.max_cache_len
+
+
+class StaticSlidingWindowLayer(StaticLayer):
+ """
+ A static cache layer that stores the key and value states as static tensors of shape
+ `[batch_size, num_heads, min(max_cache_len, sliding_window), head_dim]`. It lazily allocates its full backing
+ tensors, and then mutates them in-place. Built for `torch.compile` support.
+
+ Args:
+ max_cache_len (`int`):
+ Maximum number of tokens that can be stored, used for tensor preallocation.
+ sliding_window (`int`):
+ The size of the sliding window.
+ """
+
+ is_sliding = True
+
+ def __init__(self, max_cache_len: int, sliding_window: int):
+ effective_max_cache_len = min(sliding_window, max_cache_len)
+ super().__init__(max_cache_len=effective_max_cache_len)
+ self.cumulative_length = 0
+
+ def update(
+ self,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ cache_kwargs: Optional[dict[str, Any]] = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Update the key and value caches in-place, and return the necessary keys and value states.
+
+ Args:
+ key_states (`torch.Tensor`): The new key states to cache.
+ value_states (`torch.Tensor`): The new value states to cache.
+ cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache.
+
+ Returns:
+ tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states.
+ """
+ # Lazy initialization
+ if not self.is_initialized:
+ self.lazy_initialization(key_states)
+
+ # Some old models give None for `cache_position` or even omit passing `cache_kwargs` when used as cross-attention,
+ # in which case we should copy the whole Layer (key_states.shape[-2] == self.max_cache_len)
+ cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None
+ cache_position = (
+ cache_position if cache_position is not None else torch.arange(key_states.shape[-2], device=self.device)
+ )
+
+ cumulative_length = self.cumulative_length
+ is_full = cumulative_length >= self.max_cache_len
+ # Update it now that we saved the value above
+ self.cumulative_length += key_states.shape[-2]
+
+ if is_full:
+ # In general, we should use a much simpler `cat` here as well, independently of the states size. However,
+ # dynamo is currently bugged when doing it - see https://github.com/pytorch/pytorch/issues/159855 for more details
+ if key_states.shape[-2] == 1:
+ # Roll all values to the left by 1 position
+ new_keys = self.keys.roll(-1, dims=-2)
+ new_values = self.values.roll(-1, dims=-2)
+ # Overwrite the last position with new states
+ # (note: very important to use a tensor to index here, see https://github.com/pytorch/pytorch/issues/159855)
+ index = torch.tensor([-1], dtype=int, device=self.device)
+ new_keys[:, :, index] = key_states
+ new_values[:, :, index] = value_states
+
+ # Copy back into `self` (do not just assign again) in order to keep the static dynamo address
+ self.keys.copy_(new_keys)
+ self.values.copy_(new_values)
+ # Very important to return the `self` tensors here, as they have the static dynamo address
+ return self.keys, self.values
+ # Already full but using more than 1 new token (e.g. prefill caching, chat continuation, etc...)
+ else:
+ full_key_states = torch.cat((self.keys[:, :, 1:, :], key_states), dim=-2)
+ full_value_states = torch.cat((self.values[:, :, 1:, :], value_states), dim=-2)
+ # Not yet full, but becoming full on this update
+ elif cumulative_length + key_states.shape[2] > self.max_cache_len:
+ # Fast prefill path, no need to cat() in this case, as the cache is currently empty
+ if cumulative_length == 0:
+ full_key_states = key_states
+ full_value_states = value_states
+ else:
+ full_key_states = torch.cat((self.keys[:, :, :cumulative_length, :], key_states), dim=-2)
+ full_value_states = torch.cat((self.values[:, :, :cumulative_length, :], value_states), dim=-2)
+ else:
+ try:
+ self.keys.index_copy_(2, cache_position, key_states)
+ self.values.index_copy_(2, cache_position, value_states)
+ except NotImplementedError:
+ self.keys[:, :, cache_position] = key_states
+ self.values[:, :, cache_position] = value_states
+
+ # Very important to return the `self` tensors here, as they have the static dynamo address
+ return self.keys, self.values
+
+ # We only cache the last `sliding_window` tokens
+ self.keys.copy_(full_key_states[:, :, -self.max_cache_len :, :])
+ self.values.copy_(full_value_states[:, :, -self.max_cache_len :, :])
+ # we should return the whole states instead of `self.keys/values` here, as otherwise we lose some context
+ return full_key_states, full_value_states
+
+ def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
+ """Return the length and offset of the cache, used to generate the attention mask"""
+ query_length = cache_position.shape[0]
+ sliding_window = self.max_cache_len
+ is_full = self.cumulative_length >= self.max_cache_len
+
+ kv_offset = max(self.cumulative_length - sliding_window + 1, 0)
+ # The cache is already full
+ if is_full:
+ kv_length = sliding_window + query_length - 1
+ # Not yet full, but becoming full on this update
+ elif self.cumulative_length + query_length > sliding_window:
+ kv_length = self.cumulative_length + query_length
+ # Here the Cache is still smaller than the local size, but we return the local size as it's static
+ else:
+ kv_length = sliding_window
+
+ return kv_length, kv_offset
+
+ def get_seq_length(self) -> int:
+ """Returns the sequence length of the cached states."""
+ return self.cumulative_length
+
+
+class QuantizedLayer(DynamicLayer):
+ """
+ A quantized layer similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://huggingface.co/papers/2402.02750).
+ It allows the model to generate longer sequence length without allocating too much memory for the key and value caches by
+ applying quantization.
+
+ The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length`
+ is set as a maximum capacity for the original precision cache. When the length goes beyond maximum capacity, the original
+ precision cache is discarded and moved into the quantized cache. The quantization is done per-channel with a set `q_group_size`
+ for both Keys and Values, in contrast to what was described in the paper.
+ """
+
+ def __init__(
+ self,
+ nbits: int = 4,
+ axis_key: int = 0,
+ axis_value: int = 0,
+ q_group_size: int = 64,
+ residual_length: int = 128,
+ ):
+ super().__init__()
+ self.nbits = nbits
+ self.axis_key = axis_key
+ self.axis_value = axis_value
+ self.q_group_size = q_group_size
+ self.residual_length = residual_length
+ self.cumulative_length = 0
+
+ def update(
+ self,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ cache_kwargs: Optional[dict[str, Any]] = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Update the key and value caches in-place, and return the necessary keys and value states.
+
+ Args:
+ key_states (`torch.Tensor`): The new key states to cache.
+ value_states (`torch.Tensor`): The new value states to cache.
+ cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache.
+
+ Returns:
+ tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states.
+ """
+ self.cumulative_length += key_states.shape[-2]
+
+ # Lazy initialization
+ if not self.is_initialized:
+ self.lazy_initialization(key_states)
+ self._quantized_keys = self._quantize(key_states.contiguous(), axis=self.axis_key)
+ self._quantized_values = self._quantize(value_states.contiguous(), axis=self.axis_value)
+ return key_states, value_states
+
+ dequant_keys = self._dequantize(self._quantized_keys)
+ dequant_values = self._dequantize(self._quantized_values)
+ keys_to_return = torch.cat([dequant_keys, self.keys, key_states], dim=-2)
+ values_to_return = torch.cat([dequant_values, self.values, value_states], dim=-2)
+ if self.keys.dim() == 4 and self.keys.shape[-2] + 1 >= self.residual_length:
+ self._quantized_keys = self._quantize(keys_to_return.contiguous(), axis=self.axis_key)
+ self._quantized_values = self._quantize(values_to_return.contiguous(), axis=self.axis_value)
+ self.keys = torch.tensor([], dtype=key_states.dtype, device=key_states.device)
+ self.values = torch.tensor([], dtype=key_states.dtype, device=key_states.device)
+ else:
+ self.keys = torch.cat([self.keys, key_states], dim=-2)
+ self.values = torch.cat([self.values, value_states], dim=-2)
+
+ return keys_to_return, values_to_return
+
+ @abstractmethod
+ def _quantize(self, tensor, axis): ...
+
+ @abstractmethod
+ def _dequantize(self, q_tensor): ...
+
+ def get_seq_length(self) -> int:
+ """Returns the sequence length of the cached states."""
+ return self.cumulative_length
+
+
+class QuantoQuantizedLayer(QuantizedLayer):
+ def __init__(
+ self,
+ nbits: int = 4,
+ axis_key: int = 0,
+ axis_value: int = 0,
+ q_group_size: int = 64,
+ residual_length: int = 128,
+ ):
+ super().__init__(
+ nbits=nbits,
+ axis_key=axis_key,
+ axis_value=axis_value,
+ q_group_size=q_group_size,
+ residual_length=residual_length,
+ )
+
+ # We need to import quanto here to avoid circular imports due to optimum/quanto/models/transformers_models.py
+ if is_quanto_greater("0.2.5", accept_dev=True):
+ from optimum.quanto import MaxOptimizer, qint2, qint4
+ else:
+ raise ImportError(
+ "You need optimum-quanto package version to be greater or equal than 0.2.5 to use `QuantoQuantizedCache`. "
+ )
+
+ if self.nbits not in [2, 4]:
+ raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}")
+
+ if self.axis_key not in [0, -1]:
+ raise ValueError(f"`axis_key` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_key}")
+
+ if self.axis_value not in [0, -1]:
+ raise ValueError(
+ f"`axis_value` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_value}"
+ )
+
+ self.qtype = qint4 if self.nbits == 4 else qint2
+ self.optimizer = MaxOptimizer() # hardcode as it's the only one for per-channel quantization
+
+ def _quantize(self, tensor, axis):
+ from optimum.quanto import quantize_weight
+
+ scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.q_group_size)
+ qtensor = quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.q_group_size)
+ return qtensor
+
+ def _dequantize(self, qtensor):
+ return qtensor.dequantize()
+
+
+class HQQQuantizedLayer(QuantizedLayer):
+ def __init__(
+ self,
+ nbits: int = 4,
+ axis_key: int = 0,
+ axis_value: int = 0,
+ q_group_size: int = 64,
+ residual_length: int = 128,
+ ):
+ super().__init__(
+ nbits=nbits,
+ axis_key=axis_key,
+ axis_value=axis_value,
+ q_group_size=q_group_size,
+ residual_length=residual_length,
+ )
+
+ if not is_hqq_available():
+ raise ImportError("You need to install `hqq` to use `HQQQuantizedLayer`")
+
+ if self.nbits not in [1, 2, 3, 4, 8]:
+ raise ValueError(
+ f"`nbits` for `HQQ` backend has to be one of [`1`, `2`, `3`, `4`, `8`] but got {self.nbits}"
+ )
+
+ if self.axis_key not in [0, 1]:
+ raise ValueError(f"`axis_key` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_key}")
+
+ if self.axis_value not in [0, 1]:
+ raise ValueError(f"`axis_value` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_value}")
+
+ self.quantizer = HQQQuantizer
+
+ def _quantize(self, tensor, axis):
+ qtensor, meta = self.quantizer.quantize(
+ tensor,
+ axis=axis,
+ device=self.keys.device,
+ compute_dtype=self.keys.dtype,
+ nbits=self.nbits,
+ group_size=self.q_group_size,
+ )
+ meta["compute_dtype"] = self.keys.dtype
+ self.quantizer.cuda(qtensor, meta=meta, device=self.keys.device) # Move to device and cast to dtype
+ meta["scale"] = meta["scale"].to(qtensor.device)
+ meta["zero"] = meta["zero"].to(qtensor.device)
+ return qtensor, meta
+
+ def _dequantize(self, qtensor):
+ quant_tensor, meta = qtensor
+ tensor = self.quantizer.dequantize(quant_tensor, meta)
+ return tensor
+
+
+class Cache:
+ """
+ A `Cache` is mostly a list of `CacheLayerMixin` objects, one per model layer. It serves as a container for
+ the Cache of each layer.
+
+ Args:
+ layers (`Optional`, *optional*):
+ A list of pre-created `CacheLayerMixin`. If omitted (`None`), then `layer_class_to_replicate` will
+ be used.
+ layer_class_to_replicate (`type[CacheLayerMixin]`, *optional*):
+ Only used if `layers` is omitted (`None`), in which case it will be used as the base class for each layer,
+ and the layers will be added lazily as soon as `update` is called with a `layer_idx` greater than the current
+ list of layers.
+ offloading (`bool`, *optional*, defaults to `False`):
+ Whether to perform offloading of the layers to `cpu`, to save GPU memory.
+ offload_only_non_sliding (`bool`, *optional*, defaults to `True`):
+ If `offloading` is `True`, this further decides if only the non-sliding layers will be offloaded (because
+ usually the sliding layers are small in size, so there is no need to offload them, and skipping it is faster).
+ """
+
+ def __init__(
+ self,
+ layers: Optional[list[CacheLayerMixin]] = None,
+ layer_class_to_replicate: Optional[type[CacheLayerMixin]] = None,
+ offloading: bool = False,
+ offload_only_non_sliding: bool = True,
+ ):
+ if layers is not None and layer_class_to_replicate is not None:
+ raise ValueError(
+ "You can construct a Cache either from a list `layers` of all the predefined `CacheLayer`, or from a "
+ "`layer_class_to_replicate`, in which case the Cache will append a new layer corresponding to "
+ "`layer_class_to_replicate` for each new call to `update` with an idx not already in the Cache."
+ )
+ if layers is None and layer_class_to_replicate is None:
+ raise ValueError(
+ "You should provide exactly one of `layers` or `layer_class_to_replicate` to initialize a Cache."
+ )
+ self.layers = layers if layers is not None else []
+ self.layer_class_to_replicate = layer_class_to_replicate
+ self.offloading = offloading
+ if self.offloading:
+ self.only_non_sliding = offload_only_non_sliding
+ self.prefetch_stream = torch.Stream() if _is_torch_greater_or_equal_than_2_7 else torch.cuda.Stream()
+
+ def __repr__(self):
+ return f"{self.__class__.__name__}(layers={self.layers})"
+
+ def prefetch(self, layer_idx: int, only_non_sliding: bool = True):
+ """
+ Prefetch a given layer on its device. If `only_non_sliding` is True, it will try to prefetch only the layers
+ which are non-sliding. If the `layer_idx` is outside the range, this will circle back to the first layers.
+ Note that we use a non-default stream for this, to avoid blocking.
+ """
+ if only_non_sliding:
+ # Try to find next non-sliding, starting at `layer_idx`
+ try:
+ layer_idx = layer_idx + self.is_sliding[layer_idx:].index(False)
+ # In this case, we need to circle back to the beginning
+ except ValueError:
+ layer_idx = self.is_sliding.index(False)
+ else:
+ layer_idx = layer_idx if layer_idx < len(self.layers) else 0
+
+ # Prefetch
+ with self.prefetch_stream if _is_torch_greater_or_equal_than_2_7 else torch.cuda.stream(self.prefetch_stream):
+ self.layers[layer_idx].prefetch()
+
+ def offload(self, layer_idx: int, only_non_sliding: bool = True):
+ """
+ Offload a given `layer_idx`. If `only_non_sliding` is True, it will offload `layer_idx` only if it is a
+ non-sliding layer. Note that we do it on the default stream, so that we ensure all earlier
+ computation in the layer's `update` methods are finished.
+ """
+ if not (only_non_sliding and self.is_sliding[layer_idx]):
+ self.layers[layer_idx].offload()
+
+ def update(
+ self,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ layer_idx: int,
+ cache_kwargs: Optional[dict[str, Any]] = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
+
+ Parameters:
+ key_states (`torch.Tensor`):
+ The new key states to cache.
+ value_states (`torch.Tensor`):
+ The new value states to cache.
+ layer_idx (`int`):
+ The index of the layer to cache the states for.
+ cache_kwargs (`dict[str, Any]`, *optional*):
+ Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
+ cache to be created.
+
+ Return:
+ A tuple containing the updated key and value states.
+ """
+ # In this case, the `layers` were not provided, and we must append as much as `layer_idx`
+ if self.layer_class_to_replicate is not None:
+ while len(self.layers) <= layer_idx:
+ self.layers.append(self.layer_class_to_replicate())
+
+ if self.offloading:
+ # Wait for the stream to finish if needed, and start prefetching the next layer
+ torch.cuda.default_stream(key_states.device).wait_stream(self.prefetch_stream)
+ self.prefetch(layer_idx + 1, self.only_non_sliding)
+
+ keys, values = self.layers[layer_idx].update(key_states, value_states, cache_kwargs)
+
+ if self.offloading:
+ self.offload(layer_idx, self.only_non_sliding)
+
+ return keys, values
+
+ def early_initialization(
+ self, batch_size: int, num_heads: int, head_dim: int, dtype: torch.dtype, device: torch.device
+ ):
+ """
+ Initialize all the layers in advance (it's otherwise lazily initialized on the first `update` call).
+ This is useful for our `export` recipes, as `export` needs everything in advance.
+ """
+ # Note that the initialization needs all dimensions (except -2), as well as device and dtype, so we use
+ # this fake tensor approach. It has size 0 on the -2 dimension, so it does not allocate any data (it only
+ # creates an empty tensor with correct shape, dtype and device), which is very efficient and practical
+ fake_keys_tensor = torch.zeros((batch_size, num_heads, 0, head_dim), dtype=dtype, device=device)
+ # Init all layers
+ for layer in self.layers:
+ layer.lazy_initialization(fake_keys_tensor)
+
+ def get_seq_length(self, layer_idx: int = 0) -> int:
+ """Returns the sequence length of the cache for the given layer."""
+ if layer_idx >= len(self.layers):
+ return 0
+ return self.layers[layer_idx].get_seq_length()
+
+ def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
+ """
+ Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
+ the given layer at `layer_idx`.
+ The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns for each layer.
+ """
+ # For DynamicCache, where the layers are created at runtime -> if it was not yet created, the size is
+ # simply the shape of `cache_position`
+ if layer_idx >= len(self.layers):
+ return cache_position.shape[0], 0
+ return self.layers[layer_idx].get_mask_sizes(cache_position)
+
+ def get_max_cache_shape(self, layer_idx: int = 0) -> int:
+ """Returns maximum sequence length of the cache object. Dynamic caches do not have a maximum length."""
+ # For DynamicCache, where the layers are created at runtime -> if it was not yet created, return -1
+ # as DynamicLayer does
+ if layer_idx >= len(self.layers):
+ return -1
+ return self.layers[layer_idx].get_max_cache_shape()
+
+ def reset(self):
+ """Recursively reset all layers tensors"""
+ for layer_idx in range(len(self.layers)):
+ self.layers[layer_idx].reset()
+
+ def reorder_cache(self, beam_idx: torch.LongTensor):
+ """Reorder the cache for beam search"""
+ for layer_idx in range(len(self.layers)):
+ self.layers[layer_idx].reorder_cache(beam_idx)
+
+ def crop(self, max_length: int):
+ """Crop the cache to the given length"""
+ for layer_idx in range(len(self.layers)):
+ self.layers[layer_idx].crop(max_length)
+
+ def batch_repeat_interleave(self, repeats: int):
+ """Repeat and interleave the cache"""
+ for layer_idx in range(len(self.layers)):
+ self.layers[layer_idx].batch_repeat_interleave(repeats)
+
+ def batch_select_indices(self, indices: torch.Tensor):
+ """Select indices from the cache"""
+ for layer_idx in range(len(self.layers)):
+ self.layers[layer_idx].batch_select_indices(indices)
+
+ @property
+ def max_batch_size(self) -> int:
+ """Return the maximum batch size of the cache"""
+ values = [layer.max_batch_size for layer in self.layers]
+ if len(set(values)) > 1:
+ raise ValueError(f"Max batch size is not consistent across layers: {values}")
+ return values[0]
+
+ @property
+ def max_cache_len(self) -> int:
+ """Return the maximum cache length of the cache"""
+ values = [layer.max_cache_len for layer in self.layers]
+ return max(values)
+
+ @property
+ def is_compileable(self) -> bool:
+ """Return whether the cache is compileable"""
+ # For DynamicCache dispatching the layers lazily (otherwise, all([]) is True)
+ if len(self.layers) == 0:
+ return False
+ return all(layer.is_compileable for layer in self.layers)
+
+ @property
+ def is_initialized(self) -> bool:
+ """Return whether the cache data is initialized"""
+ return len(self.layers) > 0 and all(layer.is_initialized for layer in self.layers)
+
+ @property
+ def is_sliding(self) -> list[bool]:
+ """Return whether the layers of the cache are sliding window"""
+ return [getattr(layer, "is_sliding", False) for layer in self.layers]
+
+ def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Support for backwards-compatible `past_key_values` indexing, e.g. `past_key_values[0][0].shape[2]` to get the
+ sequence length.
+ """
+ if layer_idx < len(self.layers):
+ return self.layers[layer_idx].keys, self.layers[layer_idx].values
+ else:
+ raise KeyError(
+ f"Cache only has {len(self.layers)} layers, attempted to access layer with index {layer_idx}"
+ )
+
+ def __iter__(self):
+ """
+ Support for backwards-compatible `past_key_values` iteration, e.g. `for x in past_key_values:` to iterate over
+ keys and values
+ """
+ for layer_idx in range(len(self)):
+ yield (self.layers[layer_idx].keys, self.layers[layer_idx].values)
+
+ def __len__(self):
+ """
+ This value corresponds to the number of layers in the model.
+ """
+ # Note: for DynamicCache, layers are initialized lazily, so this will not be accurate before the first
+ # forward through all the layers
+ return len(self.layers)
+
+
+class DynamicCache(Cache):
+ """
+ A cache that grows dynamically as more tokens are generated. This is the default for generative models.
+ It stores the key and value states as a list of `CacheLayer`, one for each layer. The expected shape for each tensor
+ in the `CacheLayer`s is `[batch_size, num_heads, seq_len, head_dim]`.
+ If a config is passed, it will additionally check for sliding or hybrid cache structure, greatly reducing the
+ memory requirement of the cached tensors to `[batch_size, num_heads, min(seq_len, sliding_window), head_dim]`.
+
+ See `Cache` for details on common methods that are implemented by all cache classes.
+
+ Args:
+ ddp_cache_data (`Iterable[tuple[torch.Tensor, torch.Tensor]]`, *optional*):
+ It was originally added for compatibility with `torch.distributed` (DDP). In a nutshell, it is
+ `map(gather_map, zip(*caches))`, i.e. each item in the iterable contains the key and value states
+ for a layer gathered across replicas by torch.distributed (shape=[global batch size, num_heads, seq_len, head_dim]).
+ Note: it needs to be the 1st arg as well to work correctly
+ config (`PretrainedConfig`, *optional*):
+ The config of the model for which this Cache will be used. If passed, it will be used to check for sliding
+ or hybrid layer structure, greatly reducing the memory requirement of the cached tensors to
+ `[batch_size, num_heads, min(seq_len, sliding_window), head_dim]`.
+ offloading (`bool`, *optional*, defaults to `False`):
+ Whether to perform offloading of the layers to `cpu`, to save GPU memory.
+ offload_only_non_sliding (`bool`, *optional*, defaults to `False`):
+ If `offloading` is `True`, this further decides if only the non-sliding layers will be offloaded (because
+ usually the sliding layers are small in size, so there is no need to offload them, and skipping it is faster).
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
+
+ >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
+ >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
+
+ >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")
+
+ >>> # Prepare a cache class and pass it to model's forward
+ >>> past_key_values = DynamicCache(config=model.config)
+ >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
+ >>> outputs.past_key_values # access cache filled with key/values from generation
+ ```
+ """
+
+ def __init__(
+ self,
+ ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor]]] = None,
+ config: Optional[PretrainedConfig] = None,
+ offloading: bool = False,
+ offload_only_non_sliding: bool = False,
+ ):
+ layers = []
+ # If a config is passed, use it to infer the layer types and initialize accordingly
+ if config is not None:
+ decoder_config = config.get_text_config(decoder=True)
+ sliding_window = getattr(decoder_config, "sliding_window", None) or getattr(
+ decoder_config, "attention_chunk_size", None
+ )
+ layer_types = getattr(decoder_config, "layer_types", None)
+ if layer_types is None:
+ layer_types = [
+ "sliding_attention" if sliding_window is not None else "full_attention"
+ for _ in range(decoder_config.num_hidden_layers)
+ ]
+ # Some models have shared layers thus no cache is needed for them (e.g. Gemma3n)
+ if hasattr(decoder_config, "num_kv_shared_layers"):
+ layer_types = layer_types[: -decoder_config.num_kv_shared_layers]
+
+ for layer_type in layer_types:
+ # From a cache point of view, both sliding and chunked are the same in how they should behave and how many
+ # states they should return - only the mask changes to make them different at the end!
+ if layer_type in ("sliding_attention", "chunked_attention"):
+ layers.append(DynamicSlidingWindowLayer(sliding_window=sliding_window))
+ else:
+ layers.append(DynamicLayer())
+
+ # In this case, use the passed data to already fill in the Cache
+ if ddp_cache_data is not None:
+ # Init all the layers with the data
+ for layer_idx, (key_states, value_states) in enumerate(ddp_cache_data):
+ # If the config was not passed above, initialize a DynamicLayer for each entry of the ddp_data
+ if config is None:
+ layers.append(DynamicLayer())
+ # Update the layer with the data
+ _, _ = layers[layer_idx].update(key_states, value_states)
+
+ # If neither of config nor ddp_data was passed, then simply lazy init a full cache of DynamicLayer
+ if len(layers) == 0:
+ super().__init__(
+ layer_class_to_replicate=DynamicLayer,
+ offloading=offloading,
+ offload_only_non_sliding=offload_only_non_sliding,
+ )
+ else:
+ super().__init__(layers=layers, offloading=offloading, offload_only_non_sliding=offload_only_non_sliding)
+
+ def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor]]:
+ """
+ Converts the `Cache` instance into the its equivalent in the legacy cache format. Used for
+ backward compatibility.
+ """
+ legacy_cache = ()
+ for layer in self.layers:
+ legacy_cache += ((layer.keys, layer.values),)
+ return legacy_cache
+
+ @classmethod
+ def from_legacy_cache(cls, past_key_values: tuple[tuple[torch.Tensor, torch.Tensor]]) -> "DynamicCache":
+ """
+ Converts a cache in the legacy cache format into an equivalent `Cache`. Used for
+ backward compatibility.
+ """
+ cache = cls()
+ if past_key_values is None:
+ logger.warning_once("past_key_values should not be None in from_legacy_cache()")
+ if past_key_values is not None:
+ for layer_idx in range(len(past_key_values)):
+ key_states, value_states = past_key_values[layer_idx]
+ cache.update(key_states, value_states, layer_idx)
+ return cache
+
+
+class StaticCache(Cache):
+ """
+ Static Cache class to be used with `torch.compile(model)` and `torch.export()`. It will check the `config`
+ for potential hybrid cache structure, and initialize each layer accordingly.
+
+ See `Cache` for details on common methods that are implemented by all cache classes.
+
+ Args:
+ config (`PretrainedConfig`):
+ The config of the model for which this Cache will be used. It will be used to check for sliding
+ or hybrid layer structure, and initialize each layer accordingly.
+ max_cache_len (`int`):
+ The maximum number of tokens that this Cache should hold.
+ offloading (`bool`, *optional*, defaults to `False`):
+ Whether to perform offloading of the layers to `cpu`, to save GPU memory.
+ offload_only_non_sliding (`bool`, *optional*, defaults to `True`):
+ If `offloading` is `True`, this further decides if only the non-sliding layers will be offloaded (because
+ usually the sliding layers are small in size, so there is no need to offload them, and skipping it is faster).
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache
+
+ >>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
+
+ >>> inputs = tokenizer(text="My name is Llama", return_tensors="pt")
+
+ >>> # Prepare a cache class and pass it to model's forward
+ >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
+ >>> max_generated_length = inputs.input_ids.shape[1] + 10
+ >>> past_key_values = StaticCache(config=model.config, max_cache_len=max_generated_length)
+ >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
+ >>> outputs.past_key_values # access cache filled with key/values from generation
+ StaticCache()
+ ```
+ """
+
+ # Pass-in kwargs as well to avoid crashing for BC (it used more arguments before)
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ max_cache_len: int,
+ offloading: bool = False,
+ offload_only_non_sliding: bool = True,
+ **kwargs,
+ ):
+ config = config.get_text_config(decoder=True)
+ layer_types = getattr(config, "layer_types", None)
+ # If `layer_types` is not explicitly provided, infer if the model is fully sliding
+ if layer_types is None:
+ if getattr(config, "sliding_window", None) is not None:
+ layer_types = ["sliding_attention" for _ in range(config.num_hidden_layers)]
+ elif getattr(config, "attention_chunk_size", None) is not None:
+ layer_types = ["chunked_attention" for _ in range(config.num_hidden_layers)]
+ else:
+ layer_types = ["full_attention" for _ in range(config.num_hidden_layers)]
+ # Some models have shared layers thus no cache is needed for them (e.g. Gemma3n)
+ if hasattr(config, "num_kv_shared_layers"):
+ layer_types = layer_types[: -config.num_kv_shared_layers]
+
+ layers = []
+ for layer_type in layer_types:
+ if layer_type == "sliding_attention":
+ layer = StaticSlidingWindowLayer(max_cache_len=max_cache_len, sliding_window=config.sliding_window)
+ elif layer_type == "chunked_attention":
+ # From a cache point of view, both sliding and chunked are the same in how they should behave and how many
+ # states they should return - only the mask changes to make them different at the end!
+ layer = StaticSlidingWindowLayer(
+ max_cache_len=max_cache_len, sliding_window=config.attention_chunk_size
+ )
+ else:
+ layer = StaticLayer(max_cache_len=max_cache_len)
+ layers.append(layer)
+
+ super().__init__(layers=layers, offloading=offloading, offload_only_non_sliding=offload_only_non_sliding)
+
+
+class QuantizedCache(Cache):
+ """
+ A quantizer cache similar to what is described in the
+ [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://huggingface.co/papers/2402.02750).
+ It allows the model to generate longer sequence length without allocating too much memory for keys and values
+ by applying quantization.
+ The cache has two types of storage, one for original precision and one for the
+ quantized cache. A `residual length` is set as a maximum capacity for the original precision cache. When the
+ length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache.
+ The quantization is done per-channel with a set `q_group_size` for both keys and values, in contrast to what was
+ described in the paper.
+
+ See `Cache` for details on common methods that are implemented by all cache classes.
+
+ Args:
+ backend (`str`):
+ The quantization backend to use. One of `("quanto", "hqq").
+ config (`PretrainedConfig`):
+ The config of the model for which this Cache will be used.
+ nbits (`int`, *optional*, defaults to 4):
+ The number of bits for quantization.
+ axis_key (`int`, *optional*, defaults to 0):
+ The axis on which to quantize the keys.
+ axis_value (`int`, *optional*, defaults to 0):
+ The axis on which to quantize the values.
+ q_group_size (`int`, *optional*, defaults to 64):
+ Quantization is done per-channel according to a set `q_group_size` for both keys and values.
+ residual_length (`int`, *optional*, defaults to 128):
+ Maximum capacity for the original precision cache
+ """
+
+ def __init__(
+ self,
+ backend: str,
+ config: PretrainedConfig,
+ nbits: int = 4,
+ axis_key: int = 0,
+ axis_value: int = 0,
+ q_group_size: int = 64,
+ residual_length: int = 128,
+ ):
+ if backend == "quanto":
+ layer_class = QuantoQuantizedLayer
+ elif backend == "hqq":
+ layer_class = HQQQuantizedLayer
+ else:
+ raise ValueError(f"Unknown quantization backend `{backend}`")
+
+ config = config.get_text_config(decoder=True)
+ layers = [
+ layer_class(nbits, axis_key, axis_value, q_group_size, residual_length)
+ for _ in range(config.num_hidden_layers)
+ ]
+ super().__init__(layers=layers)
+
+
+class EncoderDecoderCache(Cache):
+ """
+ Base, abstract class for all encoder-decoder caches. Can be used to hold combinations of self-attention and
+ cross-attention caches.
+
+ See `Cache` for details on common methods that are implemented by all cache classes.
+
+ Args:
+ caches (`Iterable`):
+ Usually an iterable of length 2, containing 2 `Cache` objects, the first one for self-attention, the
+ second one for cross-attention. Can optionally also be an iterable of length 1, containing a
+ `tuple[tuple[torch.Tensor]]` (usually used for compatibility with torch dp and ddp).
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoProcessor, AutoModelForCausalLM, DynamicCache, EncoderDecoderCache
+
+ >>> model = AutoModelForCausalLM.from_pretrained("openai/whisper-small")
+ >>> processor = AutoProcessor.from_pretrained("openai/whisper-small")
+
+ >>> inputs = processor(audio=YOUR-AUDIO, return_tensors="pt")
+
+ >>> # Prepare cache classes for encoder and decoder and pass it to model's forward
+ >>> self_attention_cache = DynamicCache(config=self.config)
+ >>> cross_attention_cache = DynamicCache(config=self.config)
+ >>> past_key_values = EncoderDecoderCache(self_attention_cache, cross_attention_cache)
+ >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
+ >>> outputs.past_key_values # access cache filled with key/values from generation
+ EncoderDecoderCache()
+ ```
+ """
+
+ def __init__(self, *caches) -> None:
+ # For dp and ddp support, if only one argument is passed, it should be an iterable of tuples of tensors
+ if len(caches) == 1:
+ self.self_attention_cache = DynamicCache()
+ self.cross_attention_cache = DynamicCache()
+ # Populate cache from the iterable
+ for layer_idx, key_value_states in enumerate(caches[0]):
+ key_states, value_states = key_value_states[:2]
+ self.self_attention_cache.update(key_states, value_states, layer_idx)
+ if len(key_value_states) > 2:
+ key_states, value_states = key_value_states[2:]
+ self.cross_attention_cache.update(key_states, value_states, layer_idx)
+ # Otherwise, we should get two arguments, a self-attention cache and a cross-attention cache
+ elif len(caches) == 2:
+ if not isinstance(caches[0], Cache) or not isinstance(caches[1], Cache):
+ raise TypeError(f"One of the two arguments is not a Cache: {type(caches[0]) = }, {type(caches[1]) = }")
+ self.self_attention_cache = caches[0]
+ self.cross_attention_cache = caches[1]
+ # Error case
+ else:
+ raise ValueError(f"Expected 1 or 2 arguments, got {len(caches)}")
+
+ self.is_updated = {}
+ for layer_idx in range(len(self.cross_attention_cache)):
+ self.is_updated[layer_idx] = bool(self.cross_attention_cache.get_seq_length(layer_idx) > 0)
+
+ def __repr__(self) -> str:
+ return (
+ f"{self.__class__.__name__}(self_attention_cache={self.self_attention_cache}, cross_attention_cache="
+ f"{self.cross_attention_cache})"
+ )
+
+ def __iter__(self):
+ """
+ Support for backwards-compatible `past_key_values` iteration, e.g. `for x in past_key_values:` to iterate over
+ keys and values
+ """
+ for layer_idx in range(len(self)):
+ yield (
+ self.self_attention_cache.layers[layer_idx].keys,
+ self.self_attention_cache.layers[layer_idx].values,
+ self.cross_attention_cache.layers[layer_idx].keys,
+ self.cross_attention_cache.layers[layer_idx].values,
+ )
+
+ def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Support for backwards-compatible `past_key_values` indexing, e.g. `past_key_values[0][0].shape[2]` to get the
+ sequence length.
+ """
+ if layer_idx < len(self):
+ return (
+ self.self_attention_cache.layers[layer_idx].keys,
+ self.self_attention_cache.layers[layer_idx].values,
+ self.cross_attention_cache.layers[layer_idx].keys,
+ self.cross_attention_cache.layers[layer_idx].values,
+ )
+ else:
+ raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
+
+ def __len__(self):
+ """
+ Support for backwards-compatible `past_key_values` length, e.g. `len(past_key_values)`. This value corresponds
+ to the number of layers in the model.
+ """
+ return len(self.self_attention_cache)
+
+ def to_legacy_cache(self) -> tuple[tuple[torch.Tensor]]:
+ """Converts the `EncoderDecoderCache` instance into its equivalent in the legacy cache format."""
+ legacy_cache = ()
+ if len(self.cross_attention_cache) > 0:
+ for self_attn, cross_attn in zip(
+ self.self_attention_cache.to_legacy_cache(), self.cross_attention_cache.to_legacy_cache()
+ ):
+ legacy_cache += (self_attn + cross_attn,)
+ else:
+ legacy_cache = self.self_attention_cache.to_legacy_cache()
+ return legacy_cache
+
+ @classmethod
+ def from_legacy_cache(
+ cls, past_key_values: Optional[Iterable[tuple[torch.FloatTensor, ...]]]
+ ) -> "EncoderDecoderCache":
+ """Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`."""
+ cache = cls(DynamicCache(), DynamicCache())
+ if past_key_values is None:
+ logger.warning_once("past_key_values should not be None in from_legacy_cache()")
+ else:
+ for layer_idx, key_value_states in enumerate(past_key_values):
+ key_states, value_states = key_value_states[:2]
+ cache.self_attention_cache.update(key_states, value_states, layer_idx)
+ if len(key_value_states) > 2:
+ key_states, value_states = key_value_states[2:]
+ cache.cross_attention_cache.update(key_states, value_states, layer_idx)
+ cache.is_updated[layer_idx] = True
+ return cache
+
+ def get_seq_length(self, layer_idx: int = 0) -> int:
+ """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+ return self.self_attention_cache.get_seq_length(layer_idx)
+
+ def reset(self):
+ self.self_attention_cache.reset()
+ self.cross_attention_cache.reset()
+ for layer_idx in self.is_updated:
+ self.is_updated[layer_idx] = False
+
+ def reorder_cache(self, beam_idx: torch.LongTensor):
+ """Reorders the cache for beam search, given the selected beam indices."""
+ self.self_attention_cache.reorder_cache(beam_idx)
+ self.cross_attention_cache.reorder_cache(beam_idx)
+
+ def check_dynamic_cache(self, method: str):
+ if not (
+ isinstance(self.self_attention_cache, DynamicCache)
+ and isinstance(self.cross_attention_cache, DynamicCache)
+ ):
+ raise ValueError(
+ f"`{method}` is only defined for dynamic cache, got {self.self_attention_cache.__str__()} for the self "
+ f"attention cache and {self.cross_attention_cache.__str__()} for the cross attention cache."
+ )
+
+ # TODO(gante, sanchit-gandhi): move following functionality into `.generate`
+ def crop(self, maximum_length: int):
+ """
+ Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
+ negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search (on the Hub).
+ """
+ self.check_dynamic_cache(self.crop.__name__)
+ self.self_attention_cache.crop(maximum_length)
+
+ def batch_split(self, full_batch_size: int, split_size: int) -> "list[EncoderDecoderCache]":
+ """
+ Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
+ `_split_model_inputs()` in `generation.utils`
+ """
+ self.check_dynamic_cache(self.batch_split.__name__)
+ self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size)
+ cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size)
+
+ out = []
+ for self_attn, cross_attn in zip(self_attention_cache, cross_attention_cache):
+ out.append(EncoderDecoderCache(self_attn, cross_attn))
+ return out
+
+ def batch_repeat_interleave(self, repeats: int):
+ """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search (on the Hub)."""
+ self.check_dynamic_cache(self.batch_repeat_interleave.__name__)
+ self.self_attention_cache.batch_repeat_interleave(repeats)
+ self.cross_attention_cache.batch_repeat_interleave(repeats)
+
+ def batch_select_indices(self, indices: torch.Tensor):
+ """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search (on the Hub)."""
+ self.check_dynamic_cache(self.batch_select_indices.__name__)
+ self.self_attention_cache.batch_select_indices(indices)
+ self.cross_attention_cache.batch_select_indices(indices)
+
+ def get_max_cache_shape(self) -> int:
+ """Returns the maximum sequence length (i.e. max capacity) of the cache object"""
+ return self.self_attention_cache.get_max_cache_shape()
+
+ def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
+ return self.self_attention_cache.get_mask_sizes(cache_position, layer_idx)
+
+ @property
+ def is_sliding(self):
+ return self.self_attention_cache.is_sliding
+
+ @property
+ def is_compileable(self) -> bool:
+ return self.self_attention_cache.is_compileable
+
+
+### Deprecated classes
+
+
+class SlidingWindowLayer(StaticSlidingWindowLayer):
+ def __init__(self, max_cache_len: int, sliding_window: int):
+ logger.warning_once(
+ "`SlidingWindowLayer` is deprecated and will be removed in version v4.59 "
+ "Use `StaticSlidingWindowLayer` instead, which is a better name for it."
+ )
+ super().__init__(max_cache_len, sliding_window)
+
+
+class ChunkedSlidingLayer(StaticSlidingWindowLayer):
+ def __init__(self, max_cache_len: int, sliding_window: int):
+ logger.warning_once(
+ "`ChunkedSlidingLayer` is deprecated and will be removed in version v4.59 "
+ "Use `StaticSlidingWindowLayer` instead, which has the exact same functionalities."
+ )
+ super().__init__(max_cache_len, sliding_window)
+
+
+class OffloadedCache(DynamicCache):
+ def __init__(self) -> None:
+ logger.warning_once(
+ "`OffloadedCache` is deprecated and will be removed in version v4.59 "
+ "Use `DynamicCache(offloading=True)` instead"
+ )
+ super().__init__(offloading=True)
+
+
+class OffloadedStaticCache(StaticCache):
+ def __init__(self, config: PretrainedConfig, max_cache_len: int, *args, **kwargs):
+ logger.warning_once(
+ "`OffloadedStaticCache` is deprecated and will be removed in version v4.59 "
+ "Use `StaticCache(..., offloading=True)` instead"
+ )
+ super().__init__(config=config, max_cache_len=max_cache_len, offloading=True)
+
+
+class SlidingWindowCache(StaticCache):
+ def __init__(self, config: PretrainedConfig, max_cache_len: int, *args, **kwargs):
+ logger.warning_once(
+ "`SlidingWindowCache` is deprecated and will be removed in version v4.59 "
+ "Use `StaticCache(...)` instead which will correctly infer the type of each layer."
+ )
+ super().__init__(config=config, max_cache_len=max_cache_len)
+
+
+class HybridCache(StaticCache):
+ def __init__(self, config: PretrainedConfig, max_cache_len: int, *args, **kwargs):
+ logger.warning_once(
+ "`HybridCache` is deprecated and will be removed in version v4.59 "
+ "Use `StaticCache(...)` instead which will correctly infer the type of each layer."
+ )
+ super().__init__(config=config, max_cache_len=max_cache_len)
+
+
+class HybridChunkedCache(StaticCache):
+ def __init__(self, config: PretrainedConfig, max_cache_len: int, *args, **kwargs):
+ logger.warning_once(
+ "`HybridChunkedCache` is deprecated and will be removed in version v4.59 "
+ "Use `StaticCache(...)` instead which will correctly infer the type of each layer."
+ )
+ super().__init__(config=config, max_cache_len=max_cache_len)
+
+
+class OffloadedHybridCache(StaticCache):
+ def __init__(self, config: PretrainedConfig, max_cache_len: int, *args, **kwargs):
+ logger.warning_once(
+ "`OffloadedHybridCache` is deprecated and will be removed in version v4.59 "
+ "Use `StaticCache(..., offload=True)` instead which will correctly infer the type of each layer."
+ )
+ super().__init__(config=config, max_cache_len=max_cache_len, offloading=True)
+
+
+class QuantoQuantizedCache(QuantizedCache):
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ nbits: int = 4,
+ axis_key: int = 0,
+ axis_value: int = 0,
+ q_group_size: int = 64,
+ residual_length: int = 128,
+ ):
+ logger.warning_once(
+ "`QuantoQuantizedCache` is deprecated and will be removed in version v4.59 "
+ "Use `QuantizedCache(backend='quanto', ...)` instead."
+ )
+ super().__init__("quanto", config, nbits, axis_key, axis_value, q_group_size, residual_length)
+
+
+class HQQQuantizedCache(QuantizedCache):
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ nbits: int = 4,
+ axis_key: int = 0,
+ axis_value: int = 0,
+ q_group_size: int = 64,
+ residual_length: int = 128,
+ ):
+ logger.warning_once(
+ "`HQQQuantizedCache` is deprecated and will be removed in version v4.59 "
+ "Use `QuantizedCache(backend='hqq', ...)` instead."
+ )
+ super().__init__("hqq", config, nbits, axis_key, axis_value, q_group_size, residual_length)
+
+
+class SinkCache(Cache):
+ """
+ It is now a `custom_generate` repository on the Hub: https://huggingface.co/transformers-community/sink_cache.
+ See [these docs](https://huggingface.co/docs/transformers/generation_strategies#custom-decoding-methods) for
+ general `custom_generate`usage.
+ """
+
+ # TODO (joao, manuel): Remove this class in v4.59.0
+ def __init__(self, **kwargs) -> None:
+ raise NotImplementedError(
+ "`SinkCache` has been moved as a `custom_generate` repository on the Hub: "
+ "https://huggingface.co/transformers-community/sink_cache. See the repository for usage examples."
+ )
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/convert_tf_hub_seq_to_seq_bert_to_pytorch.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/convert_tf_hub_seq_to_seq_bert_to_pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2c825a45b605c271c3912bf5b4f3ff4b5c1a32c
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/convert_tf_hub_seq_to_seq_bert_to_pytorch.py
@@ -0,0 +1,86 @@
+# Copyright 2020 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert Seq2Seq TF Hub checkpoint."""
+
+import argparse
+
+from . import (
+ BertConfig,
+ BertGenerationConfig,
+ BertGenerationDecoder,
+ BertGenerationEncoder,
+ load_tf_weights_in_bert_generation,
+ logging,
+)
+
+
+logging.set_verbosity_info()
+
+
+def convert_tf_checkpoint_to_pytorch(tf_hub_path, pytorch_dump_path, is_encoder_named_decoder, vocab_size, is_encoder):
+ # Initialise PyTorch model
+ bert_config = BertConfig.from_pretrained(
+ "google-bert/bert-large-cased",
+ vocab_size=vocab_size,
+ max_position_embeddings=512,
+ is_decoder=True,
+ add_cross_attention=True,
+ )
+ bert_config_dict = bert_config.to_dict()
+ del bert_config_dict["type_vocab_size"]
+ config = BertGenerationConfig(**bert_config_dict)
+ if is_encoder:
+ model = BertGenerationEncoder(config)
+ else:
+ model = BertGenerationDecoder(config)
+ print(f"Building PyTorch model from configuration: {config}")
+
+ # Load weights from tf checkpoint
+ load_tf_weights_in_bert_generation(
+ model,
+ tf_hub_path,
+ model_class="bert",
+ is_encoder_named_decoder=is_encoder_named_decoder,
+ is_encoder=is_encoder,
+ )
+
+ # Save pytorch-model
+ print(f"Save PyTorch model and config to {pytorch_dump_path}")
+ model.save_pretrained(pytorch_dump_path)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ # Required parameters
+ parser.add_argument(
+ "--tf_hub_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
+ )
+ parser.add_argument(
+ "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
+ )
+ parser.add_argument(
+ "--is_encoder_named_decoder",
+ action="store_true",
+ help="If decoder has to be renamed to encoder in PyTorch model.",
+ )
+ parser.add_argument("--is_encoder", action="store_true", help="If model is an encoder.")
+ parser.add_argument("--vocab_size", default=50358, type=int, help="Vocab size of model")
+ args = parser.parse_args()
+ convert_tf_checkpoint_to_pytorch(
+ args.tf_hub_path,
+ args.pytorch_dump_path,
+ args.is_encoder_named_decoder,
+ args.vocab_size,
+ is_encoder=args.is_encoder,
+ )
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/debug_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/debug_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..920b1cf44dafd320729eef1e6a36b9a41741c83c
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/debug_utils.py
@@ -0,0 +1,346 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import collections
+
+from .utils import ExplicitEnum, is_torch_available, logging
+
+
+if is_torch_available():
+ import torch
+
+
+logger = logging.get_logger(__name__)
+
+
+class DebugUnderflowOverflow:
+ """
+ This debug class helps detect and understand where the model starts getting very large or very small, and more
+ importantly `nan` or `inf` weight and activation elements.
+
+ There are 2 working modes:
+
+ 1. Underflow/overflow detection (default)
+ 2. Specific batch absolute min/max tracing without detection
+
+ Mode 1: Underflow/overflow detection
+
+ To activate the underflow/overflow detection, initialize the object with the model :
+
+ ```python
+ debug_overflow = DebugUnderflowOverflow(model)
+ ```
+
+ then run the training as normal and if `nan` or `inf` gets detected in at least one of the weight, input or output
+ elements this module will throw an exception and will print `max_frames_to_save` frames that lead to this event,
+ each frame reporting
+
+ 1. the fully qualified module name plus the class name whose `forward` was run
+ 2. the absolute min and max value of all elements for each module weights, and the inputs and output
+
+ For example, here is the header and the last few frames in detection report for `google/mt5-small` run in fp16
+ mixed precision :
+
+ ```
+ Detected inf/nan during batch_number=0
+ Last 21 forward frames:
+ abs min abs max metadata
+ [...]
+ encoder.block.2.layer.1.DenseReluDense.wi_0 Linear
+ 2.17e-07 4.50e+00 weight
+ 1.79e-06 4.65e+00 input[0]
+ 2.68e-06 3.70e+01 output
+ encoder.block.2.layer.1.DenseReluDense.wi_1 Linear
+ 8.08e-07 2.66e+01 weight
+ 1.79e-06 4.65e+00 input[0]
+ 1.27e-04 2.37e+02 output
+ encoder.block.2.layer.1.DenseReluDense.wo Linear
+ 1.01e-06 6.44e+00 weight
+ 0.00e+00 9.74e+03 input[0]
+ 3.18e-04 6.27e+04 output
+ encoder.block.2.layer.1.DenseReluDense T5DenseGatedGeluDense
+ 1.79e-06 4.65e+00 input[0]
+ 3.18e-04 6.27e+04 output
+ encoder.block.2.layer.1.dropout Dropout
+ 3.18e-04 6.27e+04 input[0]
+ 0.00e+00 inf output
+ ```
+
+ You can see here, that `T5DenseGatedGeluDense.forward` resulted in output activations, whose absolute max value was
+ around 62.7K, which is very close to fp16's top limit of 64K. In the next frame we have `Dropout` which
+ renormalizes the weights, after it zeroed some of the elements, which pushes the absolute max value to more than
+ 64K, and we get an overflow.
+
+ As you can see it's the previous frames that we need to look into when the numbers start going into very large for
+ fp16 numbers.
+
+ The tracking is done in a forward hook, which gets invoked immediately after `forward` has completed.
+
+ By default the last 21 frames are printed. You can change the default to adjust for your needs. For example :
+
+ ```python
+ debug_overflow = DebugUnderflowOverflow(model, max_frames_to_save=100)
+ ```
+
+ To validate that you have set up this debugging feature correctly, and you intend to use it in a training that
+ may take hours to complete, first run it with normal tracing enabled for one of a few batches as explained in
+ the next section.
+
+
+ Mode 2. Specific batch absolute min/max tracing without detection
+
+ The second work mode is per-batch tracing with the underflow/overflow detection feature turned off.
+
+ Let's say you want to watch the absolute min and max values for all the ingredients of each `forward` call of a
+ given batch, and only do that for batches 1 and 3. Then you instantiate this class as :
+
+ ```python
+ debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1, 3])
+ ```
+
+ And now full batches 1 and 3 will be traced using the same format as explained above. Batches are 0-indexed.
+
+ This is helpful if you know that the program starts misbehaving after a certain batch number, so you can
+ fast-forward right to that area.
+
+
+ Early stopping:
+
+ You can also specify the batch number after which to stop the training, with :
+
+ ```python
+ debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1, 3], abort_after_batch_num=3)
+ ```
+
+ This feature is mainly useful in the tracing mode, but you can use it for any mode.
+
+
+ **Performance**:
+
+ As this module measures absolute `min`/``max` of each weight of the model on every forward it'll slow the training
+ down. Therefore remember to turn it off once the debugging needs have been met.
+
+ Args:
+ model (`nn.Module`):
+ The model to debug.
+ max_frames_to_save (`int`, *optional*, defaults to 21):
+ How many frames back to record
+ trace_batch_nums(`list[int]`, *optional*, defaults to `[]`):
+ Which batch numbers to trace (turns detection off)
+ abort_after_batch_num (`int``, *optional*):
+ Whether to abort after a certain batch number has finished
+ """
+
+ def __init__(self, model, max_frames_to_save=21, trace_batch_nums=[], abort_after_batch_num=None):
+ self.model = model
+ self.trace_batch_nums = trace_batch_nums
+ self.abort_after_batch_num = abort_after_batch_num
+
+ # keep a LIFO buffer of frames to dump as soon as inf/nan is encountered to give context to the problem emergence
+ self.frames = collections.deque([], max_frames_to_save)
+ self.frame = []
+ self.batch_number = 0
+ self.total_calls = 0
+ self.detected_overflow = False
+ self.prefix = " "
+
+ self.analyse_model()
+
+ self.register_forward_hook()
+
+ def save_frame(self, frame=None):
+ if frame is not None:
+ self.expand_frame(frame)
+ self.frames.append("\n".join(self.frame))
+ self.frame = [] # start a new frame
+
+ def expand_frame(self, line):
+ self.frame.append(line)
+
+ def trace_frames(self):
+ print("\n".join(self.frames))
+ self.frames = []
+
+ def reset_saved_frames(self):
+ self.frames = []
+
+ def dump_saved_frames(self):
+ print(f"\nDetected inf/nan during batch_number={self.batch_number}")
+ print(f"Last {len(self.frames)} forward frames:")
+ print(f"{'abs min':8} {'abs max':8} metadata")
+ print("\n".join(self.frames))
+ print("\n\n")
+ self.frames = []
+
+ def analyse_model(self):
+ # extract the fully qualified module names, to be able to report at run time. e.g.:
+ # encoder.block.2.layer.0.SelfAttention.o
+ #
+ # for shared weights only the first shared module name will be registered
+ self.module_names = {m: name for name, m in self.model.named_modules()}
+ # self.longest_module_name = max(len(v) for v in self.module_names.values())
+
+ def analyse_variable(self, var, ctx):
+ if torch.is_tensor(var):
+ self.expand_frame(get_abs_min_max(var, ctx))
+ if detect_overflow(var, ctx):
+ self.detected_overflow = True
+ elif var is None:
+ self.expand_frame(f"{'None':>17} {ctx}")
+ else:
+ self.expand_frame(f"{'not a tensor':>17} {ctx}")
+
+ def batch_start_frame(self):
+ self.expand_frame(f"\n\n{self.prefix} *** Starting batch number={self.batch_number} ***")
+ self.expand_frame(f"{'abs min':8} {'abs max':8} metadata")
+
+ def batch_end_frame(self):
+ self.expand_frame(f"{self.prefix} *** Finished batch number={self.batch_number - 1} ***\n\n")
+
+ def create_frame(self, module, input, output):
+ self.expand_frame(f"{self.prefix} {self.module_names[module]} {module.__class__.__name__}")
+
+ # params
+ for name, p in module.named_parameters(recurse=False):
+ self.analyse_variable(p, name)
+
+ # inputs
+ if isinstance(input, tuple):
+ for i, x in enumerate(input):
+ self.analyse_variable(x, f"input[{i}]")
+ else:
+ self.analyse_variable(input, "input")
+
+ # outputs
+ if isinstance(output, tuple):
+ for i, x in enumerate(output):
+ # possibly a tuple of tuples
+ if isinstance(x, tuple):
+ for j, y in enumerate(x):
+ self.analyse_variable(y, f"output[{i}][{j}]")
+ else:
+ self.analyse_variable(x, f"output[{i}]")
+ else:
+ self.analyse_variable(output, "output")
+
+ self.save_frame()
+
+ def register_forward_hook(self):
+ self.model.apply(self._register_forward_hook)
+
+ def _register_forward_hook(self, module):
+ module.register_forward_hook(self.forward_hook)
+
+ def forward_hook(self, module, input, output):
+ # - input is a tuple of packed inputs (could be non-Tensors)
+ # - output could be a Tensor or a tuple of Tensors and non-Tensors
+
+ last_frame_of_batch = False
+
+ trace_mode = self.batch_number in self.trace_batch_nums
+ if trace_mode:
+ self.reset_saved_frames()
+
+ if self.total_calls == 0:
+ self.batch_start_frame()
+ self.total_calls += 1
+
+ # count batch numbers - the very first forward hook of the batch will be called when the
+ # batch completes - i.e. it gets called very last - we know this batch has finished
+ if module == self.model:
+ self.batch_number += 1
+ last_frame_of_batch = True
+
+ self.create_frame(module, input, output)
+
+ # if last_frame_of_batch:
+ # self.batch_end_frame()
+
+ if trace_mode:
+ self.trace_frames()
+
+ if last_frame_of_batch:
+ self.batch_start_frame()
+
+ if self.detected_overflow and not trace_mode:
+ self.dump_saved_frames()
+
+ # now we can abort, as it's pointless to continue running
+ raise ValueError(
+ "DebugUnderflowOverflow: inf/nan detected, aborting as there is no point running further. "
+ "Please scroll up above this traceback to see the activation values prior to this event."
+ )
+
+ # abort after certain batch if requested to do so
+ if self.abort_after_batch_num is not None and self.batch_number > self.abort_after_batch_num:
+ raise ValueError(
+ f"DebugUnderflowOverflow: aborting after {self.batch_number} batches due to"
+ f" `abort_after_batch_num={self.abort_after_batch_num}` arg"
+ )
+
+
+def get_abs_min_max(var, ctx):
+ abs_var = var.abs()
+ return f"{abs_var.min():8.2e} {abs_var.max():8.2e} {ctx}"
+
+
+def detect_overflow(var, ctx):
+ """
+ Report whether the tensor contains any `nan` or `inf` entries.
+
+ This is useful for detecting overflows/underflows and best to call right after the function that did some math that
+ modified the tensor in question.
+
+ This function contains a few other helper features that you can enable and tweak directly if you want to track
+ various other things.
+
+ Args:
+ var: the tensor variable to check
+ ctx: the message to print as a context
+
+ Return:
+ `True` if `inf` or `nan` was detected, `False` otherwise
+ """
+ detected = False
+ if torch.isnan(var).any().item():
+ detected = True
+ print(f"{ctx} has nans")
+ if torch.isinf(var).any().item():
+ detected = True
+ print(f"{ctx} has infs")
+
+ # if needed to monitor large elements can enable the following
+ if 0: # and detected:
+ n100 = var[torch.ge(var.abs(), 100)]
+ if n100.numel() > 0:
+ print(f"{ctx}: n100={n100.numel()}")
+ n1000 = var[torch.ge(var.abs(), 1000)]
+ if n1000.numel() > 0:
+ print(f"{ctx}: n1000={n1000.numel()}")
+ n10000 = var[torch.ge(var.abs(), 10000)]
+ if n10000.numel() > 0:
+ print(f"{ctx}: n10000={n10000.numel()}")
+
+ if 0:
+ print(f"min={var.min():9.2e} max={var.max():9.2e}")
+
+ if 0:
+ print(f"min={var.min():9.2e} max={var.max():9.2e} var={var.var():9.2e} mean={var.mean():9.2e} ({ctx})")
+
+ return detected
+
+
+class DebugOption(ExplicitEnum):
+ UNDERFLOW_OVERFLOW = "underflow_overflow"
+ TPU_METRICS_DEBUG = "tpu_metrics_debug"
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/dependency_versions_table.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/dependency_versions_table.py
new file mode 100644
index 0000000000000000000000000000000000000000..42bbcbaabfad340b2e08abe722dbdfe50639e40c
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/dependency_versions_table.py
@@ -0,0 +1,108 @@
+# THIS FILE HAS BEEN AUTOGENERATED. To update:
+# 1. modify the `_deps` dict in setup.py
+# 2. run `make deps_table_update``
+deps = {
+ "Pillow": "Pillow>=10.0.1,<=15.0",
+ "accelerate": "accelerate>=0.26.0",
+ "av": "av",
+ "beautifulsoup4": "beautifulsoup4",
+ "blobfile": "blobfile",
+ "codecarbon": "codecarbon>=2.8.1",
+ "cookiecutter": "cookiecutter==1.7.3",
+ "dataclasses": "dataclasses",
+ "datasets": "datasets>=2.15.0",
+ "deepspeed": "deepspeed>=0.9.3",
+ "diffusers": "diffusers",
+ "dill": "dill<0.3.5",
+ "evaluate": "evaluate>=0.2.0",
+ "faiss-cpu": "faiss-cpu",
+ "fastapi": "fastapi",
+ "filelock": "filelock",
+ "flax": "flax>=0.4.1,<=0.7.0",
+ "ftfy": "ftfy",
+ "fugashi": "fugashi>=1.0",
+ "GitPython": "GitPython<3.1.19",
+ "hf-doc-builder": "hf-doc-builder>=0.3.0",
+ "hf_xet": "hf_xet",
+ "huggingface-hub": "huggingface-hub>=0.34.0,<1.0",
+ "importlib_metadata": "importlib_metadata",
+ "ipadic": "ipadic>=1.0.0,<2.0",
+ "jax": "jax>=0.4.1,<=0.4.13",
+ "jaxlib": "jaxlib>=0.4.1,<=0.4.13",
+ "jinja2": "jinja2>=3.1.0",
+ "kenlm": "kenlm",
+ "keras": "keras>2.9,<2.16",
+ "keras-nlp": "keras-nlp>=0.3.1,<0.14.0",
+ "kernels": "kernels>=0.6.1,<=0.9",
+ "librosa": "librosa",
+ "natten": "natten>=0.14.6,<0.15.0",
+ "nltk": "nltk<=3.8.1",
+ "num2words": "num2words",
+ "numpy": "numpy>=1.17",
+ "onnxconverter-common": "onnxconverter-common",
+ "onnxruntime-tools": "onnxruntime-tools>=1.4.2",
+ "onnxruntime": "onnxruntime>=1.4.0",
+ "openai": "openai>=1.98.0",
+ "opencv-python": "opencv-python",
+ "optimum-benchmark": "optimum-benchmark>=0.3.0",
+ "optuna": "optuna",
+ "optax": "optax>=0.0.8,<=0.1.4",
+ "pandas": "pandas<2.3.0",
+ "packaging": "packaging>=20.0",
+ "parameterized": "parameterized>=0.9",
+ "phonemizer": "phonemizer",
+ "protobuf": "protobuf",
+ "psutil": "psutil",
+ "pyyaml": "pyyaml>=5.1",
+ "pydantic": "pydantic>=2",
+ "pytest": "pytest>=7.2.0",
+ "pytest-asyncio": "pytest-asyncio",
+ "pytest-rerunfailures": "pytest-rerunfailures<16.0",
+ "pytest-timeout": "pytest-timeout",
+ "pytest-xdist": "pytest-xdist",
+ "pytest-order": "pytest-order",
+ "python": "python>=3.9.0",
+ "ray[tune]": "ray[tune]>=2.7.0",
+ "regex": "regex!=2019.12.17",
+ "requests": "requests",
+ "rhoknp": "rhoknp>=1.1.0,<1.3.1",
+ "rjieba": "rjieba",
+ "rouge-score": "rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1",
+ "ruff": "ruff==0.13.1",
+ "sacrebleu": "sacrebleu>=1.4.12,<2.0.0",
+ "sacremoses": "sacremoses",
+ "safetensors": "safetensors>=0.4.3",
+ "sagemaker": "sagemaker>=2.31.0",
+ "schedulefree": "schedulefree>=1.2.6",
+ "scikit-learn": "scikit-learn",
+ "scipy": "scipy<1.13.0",
+ "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
+ "sigopt": "sigopt",
+ "starlette": "starlette",
+ "sudachipy": "sudachipy>=0.6.6",
+ "sudachidict_core": "sudachidict_core>=20220729",
+ "tensorboard": "tensorboard",
+ "tensorflow-cpu": "tensorflow-cpu>2.9,<2.16",
+ "tensorflow": "tensorflow>2.9,<2.16",
+ "tensorflow-text": "tensorflow-text<2.16",
+ "tensorflow-probability": "tensorflow-probability<0.24",
+ "tf2onnx": "tf2onnx",
+ "timeout-decorator": "timeout-decorator",
+ "tiktoken": "tiktoken",
+ "timm": "timm<=1.0.19,!=1.0.18",
+ "tokenizers": "tokenizers>=0.22.0,<=0.23.0",
+ "torch": "torch>=2.2",
+ "torchaudio": "torchaudio",
+ "torchvision": "torchvision",
+ "pyctcdecode": "pyctcdecode>=0.4.0",
+ "tqdm": "tqdm>=4.27",
+ "unidic": "unidic>=1.0.2",
+ "unidic_lite": "unidic_lite>=1.0.7",
+ "urllib3": "urllib3<2.0.0",
+ "uvicorn": "uvicorn",
+ "pytest-rich": "pytest-rich",
+ "libcst": "libcst",
+ "rich": "rich",
+ "opentelemetry-api": "opentelemetry-api",
+ "mistral-common[opencv]": "mistral-common[opencv]>=1.6.3",
+}
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/dynamic_module_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/dynamic_module_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d4e2bf48921ffdb5f94c02f966225d97efde05f
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/dynamic_module_utils.py
@@ -0,0 +1,843 @@
+# Copyright 2021 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utilities to dynamically load objects from the Hub."""
+
+import ast
+import filecmp
+import hashlib
+import importlib
+import importlib.metadata
+import importlib.util
+import keyword
+import os
+import re
+import shutil
+import signal
+import sys
+import threading
+import warnings
+from pathlib import Path
+from types import ModuleType
+from typing import Any, Optional, Union
+
+from huggingface_hub import try_to_load_from_cache
+from packaging import version
+
+from .utils import (
+ HF_MODULES_CACHE,
+ TRANSFORMERS_DYNAMIC_MODULE_NAME,
+ cached_file,
+ extract_commit_hash,
+ is_offline_mode,
+ logging,
+)
+from .utils.import_utils import VersionComparison, split_package_version
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def _sanitize_module_name(name: str) -> str:
+ r"""
+ Tries to sanitize a module name so that it can be used as a Python module.
+
+ The following transformations are applied:
+
+ 1. Replace `.` in module names with `_dot_`.
+ 2. Replace `-` in module names with `_hyphen_`.
+ 3. If the module name starts with a digit, prepend it with `_`.
+ 4. Warn if the sanitized name is a Python reserved keyword or not a valid identifier.
+
+ If the input name is already a valid identifier, it is returned unchanged.
+ """
+ # We not replacing `\W` characters with `_` to avoid collisions. Because `_` is a very common
+ # separator used in module names, replacing `\W` with `_` would create too many collisions.
+ # Once a module is imported, it is cached in `sys.modules` and the second import would return
+ # the first module, which might not be the expected behavior if name collisions happen.
+ new_name = name.replace(".", "_dot_").replace("-", "_hyphen_")
+ if new_name and new_name[0].isdigit():
+ new_name = f"_{new_name}"
+ if keyword.iskeyword(new_name):
+ logger.warning(
+ f"The module name {new_name} (originally {name}) is a reserved keyword in Python. "
+ "Please rename the original module to avoid import issues."
+ )
+ elif not new_name.isidentifier():
+ logger.warning(
+ f"The module name {new_name} (originally {name}) is not a valid Python identifier. "
+ "Please rename the original module to avoid import issues."
+ )
+ return new_name
+
+
+_HF_REMOTE_CODE_LOCK = threading.Lock()
+
+
+def init_hf_modules():
+ """
+ Creates the cache directory for modules with an init, and adds it to the Python path.
+ """
+ # This function has already been executed if HF_MODULES_CACHE already is in the Python path.
+ if HF_MODULES_CACHE in sys.path:
+ return
+
+ sys.path.append(HF_MODULES_CACHE)
+ os.makedirs(HF_MODULES_CACHE, exist_ok=True)
+ init_path = Path(HF_MODULES_CACHE) / "__init__.py"
+ if not init_path.exists():
+ init_path.touch()
+ importlib.invalidate_caches()
+
+
+def create_dynamic_module(name: Union[str, os.PathLike]) -> None:
+ """
+ Creates a dynamic module in the cache directory for modules.
+
+ Args:
+ name (`str` or `os.PathLike`):
+ The name of the dynamic module to create.
+ """
+ init_hf_modules()
+ dynamic_module_path = (Path(HF_MODULES_CACHE) / name).resolve()
+ # If the parent module does not exist yet, recursively create it.
+ if not dynamic_module_path.parent.exists():
+ create_dynamic_module(dynamic_module_path.parent)
+ os.makedirs(dynamic_module_path, exist_ok=True)
+ init_path = dynamic_module_path / "__init__.py"
+ if not init_path.exists():
+ init_path.touch()
+ # It is extremely important to invalidate the cache when we change stuff in those modules, or users end up
+ # with errors about module that do not exist. Same for all other `invalidate_caches` in this file.
+ importlib.invalidate_caches()
+
+
+def get_relative_imports(module_file: Union[str, os.PathLike]) -> list[str]:
+ """
+ Get the list of modules that are relatively imported in a module file.
+
+ Args:
+ module_file (`str` or `os.PathLike`): The module file to inspect.
+
+ Returns:
+ `list[str]`: The list of relative imports in the module.
+ """
+ with open(module_file, encoding="utf-8") as f:
+ content = f.read()
+
+ # Imports of the form `import .xxx`
+ relative_imports = re.findall(r"^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE)
+ # Imports of the form `from .xxx import yyy`
+ relative_imports += re.findall(r"^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE)
+ # Unique-ify
+ return list(set(relative_imports))
+
+
+def get_relative_import_files(module_file: Union[str, os.PathLike]) -> list[str]:
+ """
+ Get the list of all files that are needed for a given module. Note that this function recurses through the relative
+ imports (if a imports b and b imports c, it will return module files for b and c).
+
+ Args:
+ module_file (`str` or `os.PathLike`): The module file to inspect.
+
+ Returns:
+ `list[str]`: The list of all relative imports a given module needs (recursively), which will give us the list
+ of module files a given module needs.
+ """
+ no_change = False
+ files_to_check = [module_file]
+ all_relative_imports = []
+
+ # Let's recurse through all relative imports
+ while not no_change:
+ new_imports = []
+ for f in files_to_check:
+ new_imports.extend(get_relative_imports(f))
+
+ module_path = Path(module_file).parent
+ new_import_files = [f"{str(module_path / m)}.py" for m in new_imports]
+ files_to_check = [f for f in new_import_files if f not in all_relative_imports]
+
+ no_change = len(files_to_check) == 0
+ all_relative_imports.extend(files_to_check)
+
+ return all_relative_imports
+
+
+def get_imports(filename: Union[str, os.PathLike]) -> list[str]:
+ """
+ Extracts all the libraries (not relative imports this time) that are imported in a file.
+
+ Args:
+ filename (`str` or `os.PathLike`): The module file to inspect.
+
+ Returns:
+ `list[str]`: The list of all packages required to use the input module.
+ """
+ with open(filename, encoding="utf-8") as f:
+ content = f.read()
+ imported_modules = set()
+
+ import transformers.utils
+
+ def recursive_look_for_imports(node):
+ if isinstance(node, ast.Try):
+ return # Don't recurse into Try blocks and ignore imports in them
+ elif isinstance(node, ast.If):
+ test = node.test
+ for condition_node in ast.walk(test):
+ if isinstance(condition_node, ast.Call):
+ check_function = getattr(condition_node.func, "id", "")
+ if (
+ check_function.endswith("available")
+ and check_function.startswith("is_flash_attn")
+ or hasattr(transformers.utils.import_utils, check_function)
+ ):
+ # Don't recurse into "if flash_attn_available()" or any "if library_available" blocks
+ # that appears in `transformers.utils.import_utils` and ignore imports in them
+ return
+ elif isinstance(node, ast.Import):
+ # Handle 'import x' statements
+ for alias in node.names:
+ top_module = alias.name.split(".")[0]
+ if top_module:
+ imported_modules.add(top_module)
+ elif isinstance(node, ast.ImportFrom):
+ # Handle 'from x import y' statements, ignoring relative imports
+ if node.level == 0 and node.module:
+ top_module = node.module.split(".")[0]
+ if top_module:
+ imported_modules.add(top_module)
+
+ # Recursively visit all children
+ for child in ast.iter_child_nodes(node):
+ recursive_look_for_imports(child)
+
+ tree = ast.parse(content)
+ recursive_look_for_imports(tree)
+
+ return sorted(imported_modules)
+
+
+def check_imports(filename: Union[str, os.PathLike]) -> list[str]:
+ """
+ Check if the current Python environment contains all the libraries that are imported in a file. Will raise if a
+ library is missing.
+
+ Args:
+ filename (`str` or `os.PathLike`): The module file to check.
+
+ Returns:
+ `list[str]`: The list of relative imports in the file.
+ """
+ imports = get_imports(filename)
+ missing_packages = []
+ for imp in imports:
+ try:
+ importlib.import_module(imp)
+ except ImportError as exception:
+ logger.warning(f"Encountered exception while importing {imp}: {exception}")
+ # Some packages can fail with an ImportError because of a dependency issue.
+ # This check avoids hiding such errors.
+ # See https://github.com/huggingface/transformers/issues/33604
+ if "No module named" in str(exception):
+ missing_packages.append(imp)
+ else:
+ raise
+
+ if len(missing_packages) > 0:
+ raise ImportError(
+ "This modeling file requires the following packages that were not found in your environment: "
+ f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`"
+ )
+
+ return get_relative_imports(filename)
+
+
+def get_class_in_module(
+ class_name: str,
+ module_path: Union[str, os.PathLike],
+ *,
+ force_reload: bool = False,
+) -> type:
+ """
+ Import a module on the cache directory for modules and extract a class from it.
+
+ Args:
+ class_name (`str`): The name of the class to import.
+ module_path (`str` or `os.PathLike`): The path to the module to import.
+ force_reload (`bool`, *optional*, defaults to `False`):
+ Whether to reload the dynamic module from file if it already exists in `sys.modules`.
+ Otherwise, the module is only reloaded if the file has changed.
+
+ Returns:
+ `typing.Type`: The class looked for.
+ """
+ name = os.path.normpath(module_path)
+ name = name.removesuffix(".py")
+ name = name.replace(os.path.sep, ".")
+ module_file: Path = Path(HF_MODULES_CACHE) / module_path
+ with _HF_REMOTE_CODE_LOCK:
+ if force_reload:
+ sys.modules.pop(name, None)
+ importlib.invalidate_caches()
+ cached_module: Optional[ModuleType] = sys.modules.get(name)
+ module_spec = importlib.util.spec_from_file_location(name, location=module_file)
+
+ # Hash the module file and all its relative imports to check if we need to reload it
+ module_files: list[Path] = [module_file] + sorted(map(Path, get_relative_import_files(module_file)))
+ module_hash: str = hashlib.sha256(b"".join(bytes(f) + f.read_bytes() for f in module_files)).hexdigest()
+
+ module: ModuleType
+ if cached_module is None:
+ module = importlib.util.module_from_spec(module_spec)
+ # insert it into sys.modules before any loading begins
+ sys.modules[name] = module
+ else:
+ module = cached_module
+ # reload in both cases, unless the module is already imported and the hash hits
+ if getattr(module, "__transformers_module_hash__", "") != module_hash:
+ module_spec.loader.exec_module(module)
+ module.__transformers_module_hash__ = module_hash
+ return getattr(module, class_name)
+
+
+def get_cached_module_file(
+ pretrained_model_name_or_path: Union[str, os.PathLike],
+ module_file: str,
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
+ force_download: bool = False,
+ resume_download: Optional[bool] = None,
+ proxies: Optional[dict[str, str]] = None,
+ token: Optional[Union[bool, str]] = None,
+ revision: Optional[str] = None,
+ local_files_only: bool = False,
+ repo_type: Optional[str] = None,
+ _commit_hash: Optional[str] = None,
+ **deprecated_kwargs,
+) -> str:
+ """
+ Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached
+ Transformers module.
+
+ Args:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ This can be either:
+
+ - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
+ huggingface.co.
+ - a path to a *directory* containing a configuration file saved using the
+ [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
+
+ module_file (`str`):
+ The name of the module file containing the class to look for.
+ cache_dir (`str` or `os.PathLike`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
+ cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force to (re-)download the configuration files and override the cached versions if they
+ exist.
+ resume_download:
+ Deprecated and ignored. All downloads are now resumed by default when possible.
+ Will be removed in v5 of Transformers.
+ proxies (`dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `hf auth login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ If `True`, will only try to load the tokenizer configuration from local files.
+ repo_type (`str`, *optional*):
+ Specify the repo type (useful when downloading from a space for instance).
+
+
+
+ Passing `token=True` is required when you want to use a private model.
+
+
+
+ Returns:
+ `str`: The path to the module inside the cache.
+ """
+ use_auth_token = deprecated_kwargs.pop("use_auth_token", None)
+ if use_auth_token is not None:
+ warnings.warn(
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+ FutureWarning,
+ )
+ if token is not None:
+ raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
+ token = use_auth_token
+
+ if is_offline_mode() and not local_files_only:
+ logger.info("Offline mode: forcing local_files_only=True")
+ local_files_only = True
+
+ # Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file.
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+ is_local = os.path.isdir(pretrained_model_name_or_path)
+ if is_local:
+ submodule = _sanitize_module_name(os.path.basename(pretrained_model_name_or_path))
+ else:
+ submodule = os.path.sep.join(map(_sanitize_module_name, pretrained_model_name_or_path.split("/")))
+ cached_module = try_to_load_from_cache(
+ pretrained_model_name_or_path, module_file, cache_dir=cache_dir, revision=_commit_hash, repo_type=repo_type
+ )
+
+ new_files = []
+ try:
+ # Load from URL or cache if already cached
+ resolved_module_file = cached_file(
+ pretrained_model_name_or_path,
+ module_file,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ repo_type=repo_type,
+ _commit_hash=_commit_hash,
+ )
+ if not is_local and cached_module != resolved_module_file:
+ new_files.append(module_file)
+
+ except OSError:
+ logger.info(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
+ raise
+
+ # Check we have all the requirements in our environment
+ modules_needed = check_imports(resolved_module_file)
+
+ # Now we move the module inside our cached dynamic modules.
+ full_submodule = TRANSFORMERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
+ create_dynamic_module(full_submodule)
+ submodule_path = Path(HF_MODULES_CACHE) / full_submodule
+ if submodule == _sanitize_module_name(os.path.basename(pretrained_model_name_or_path)):
+ # We copy local files to avoid putting too many folders in sys.path. This copy is done when the file is new or
+ # has changed since last copy.
+ if not (submodule_path / module_file).exists() or not filecmp.cmp(
+ resolved_module_file, str(submodule_path / module_file)
+ ):
+ (submodule_path / module_file).parent.mkdir(parents=True, exist_ok=True)
+ shutil.copy(resolved_module_file, submodule_path / module_file)
+ importlib.invalidate_caches()
+ for module_needed in modules_needed:
+ module_needed = Path(module_file).parent / f"{module_needed}.py"
+ module_needed_file = os.path.join(pretrained_model_name_or_path, module_needed)
+ if not (submodule_path / module_needed).exists() or not filecmp.cmp(
+ module_needed_file, str(submodule_path / module_needed)
+ ):
+ shutil.copy(module_needed_file, submodule_path / module_needed)
+ importlib.invalidate_caches()
+ else:
+ # Get the commit hash
+ commit_hash = extract_commit_hash(resolved_module_file, _commit_hash)
+
+ # The module file will end up being placed in a subfolder with the git hash of the repo. This way we get the
+ # benefit of versioning.
+ submodule_path = submodule_path / commit_hash
+ full_submodule = full_submodule + os.path.sep + commit_hash
+ full_submodule_module_file_path = os.path.join(full_submodule, module_file)
+ create_dynamic_module(Path(full_submodule_module_file_path).parent)
+
+ if not (submodule_path / module_file).exists():
+ shutil.copy(resolved_module_file, submodule_path / module_file)
+ importlib.invalidate_caches()
+ # Make sure we also have every file with relative
+ for module_needed in modules_needed:
+ if not ((submodule_path / module_file).parent / f"{module_needed}.py").exists():
+ get_cached_module_file(
+ pretrained_model_name_or_path,
+ f"{Path(module_file).parent / module_needed}.py",
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ local_files_only=local_files_only,
+ _commit_hash=commit_hash,
+ )
+ new_files.append(f"{module_needed}.py")
+
+ if len(new_files) > 0 and revision is None:
+ new_files = "\n".join([f"- {f}" for f in new_files])
+ repo_type_str = "" if repo_type is None else f"{repo_type}s/"
+ url = f"https://huggingface.co/{repo_type_str}{pretrained_model_name_or_path}"
+ logger.warning(
+ f"A new version of the following files was downloaded from {url}:\n{new_files}"
+ "\n. Make sure to double-check they do not contain any added malicious code. To avoid downloading new "
+ "versions of the code file, you can pin a revision."
+ )
+
+ return os.path.join(full_submodule, module_file)
+
+
+def get_class_from_dynamic_module(
+ class_reference: str,
+ pretrained_model_name_or_path: Union[str, os.PathLike],
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
+ force_download: bool = False,
+ resume_download: Optional[bool] = None,
+ proxies: Optional[dict[str, str]] = None,
+ token: Optional[Union[bool, str]] = None,
+ revision: Optional[str] = None,
+ local_files_only: bool = False,
+ repo_type: Optional[str] = None,
+ code_revision: Optional[str] = None,
+ **kwargs,
+) -> type:
+ """
+ Extracts a class from a module file, present in the local folder or repository of a model.
+
+
+
+ Calling this function will execute the code in the module file found locally or downloaded from the Hub. It should
+ therefore only be called on trusted repos.
+
+
+
+
+
+ Args:
+ class_reference (`str`):
+ The full name of the class to load, including its module and optionally its repo.
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ This can be either:
+
+ - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
+ huggingface.co.
+ - a path to a *directory* containing a configuration file saved using the
+ [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
+
+ This is used when `class_reference` does not specify another repo.
+ module_file (`str`):
+ The name of the module file containing the class to look for.
+ class_name (`str`):
+ The name of the class to import in the module.
+ cache_dir (`str` or `os.PathLike`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
+ cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force to (re-)download the configuration files and override the cached versions if they
+ exist.
+ resume_download:
+ Deprecated and ignored. All downloads are now resumed by default when possible.
+ Will be removed in v5 of Transformers.
+ proxies (`dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
+ token (`str` or `bool`, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `hf auth login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ If `True`, will only try to load the tokenizer configuration from local files.
+ repo_type (`str`, *optional*):
+ Specify the repo type (useful when downloading from a space for instance).
+ code_revision (`str`, *optional*, defaults to `"main"`):
+ The specific revision to use for the code on the Hub, if the code leaves in a different repository than the
+ rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for
+ storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
+
+
+
+ Passing `token=True` is required when you want to use a private model.
+
+
+
+ Returns:
+ `typing.Type`: The class, dynamically imported from the module.
+
+ Examples:
+
+ ```python
+ # Download module `modeling.py` from huggingface.co and cache then extract the class `MyBertModel` from this
+ # module.
+ cls = get_class_from_dynamic_module("modeling.MyBertModel", "sgugger/my-bert-model")
+
+ # Download module `modeling.py` from a given repo and cache then extract the class `MyBertModel` from this
+ # module.
+ cls = get_class_from_dynamic_module("sgugger/my-bert-model--modeling.MyBertModel", "sgugger/another-bert-model")
+ ```"""
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ if use_auth_token is not None:
+ warnings.warn(
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+ FutureWarning,
+ )
+ if token is not None:
+ raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
+ token = use_auth_token
+
+ # Catch the name of the repo if it's specified in `class_reference`
+ if "--" in class_reference:
+ repo_id, class_reference = class_reference.split("--")
+ else:
+ repo_id = pretrained_model_name_or_path
+ module_file, class_name = class_reference.split(".")
+
+ if code_revision is None and pretrained_model_name_or_path == repo_id:
+ code_revision = revision
+ # And lastly we get the class inside our newly created module
+ final_module = get_cached_module_file(
+ repo_id,
+ module_file + ".py",
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ token=token,
+ revision=code_revision,
+ local_files_only=local_files_only,
+ repo_type=repo_type,
+ )
+ return get_class_in_module(class_name, final_module, force_reload=force_download)
+
+
+def custom_object_save(obj: Any, folder: Union[str, os.PathLike], config: Optional[dict] = None) -> list[str]:
+ """
+ Save the modeling files corresponding to a custom model/configuration/tokenizer etc. in a given folder. Optionally
+ adds the proper fields in a config.
+
+ Args:
+ obj (`Any`): The object for which to save the module files.
+ folder (`str` or `os.PathLike`): The folder where to save.
+ config (`PretrainedConfig` or dictionary, `optional`):
+ A config in which to register the auto_map corresponding to this custom object.
+
+ Returns:
+ `list[str]`: The list of files saved.
+ """
+ if obj.__module__ == "__main__":
+ logger.warning(
+ f"We can't save the code defining {obj} in {folder} as it's been defined in __main__. You should put "
+ "this code in a separate module so we can include it in the saved folder and make it easier to share via "
+ "the Hub."
+ )
+ return
+
+ def _set_auto_map_in_config(_config):
+ module_name = obj.__class__.__module__
+ last_module = module_name.split(".")[-1]
+ full_name = f"{last_module}.{obj.__class__.__name__}"
+ # Special handling for tokenizers
+ if "Tokenizer" in full_name:
+ slow_tokenizer_class = None
+ fast_tokenizer_class = None
+ if obj.__class__.__name__.endswith("Fast"):
+ # Fast tokenizer: we have the fast tokenizer class and we may have the slow one has an attribute.
+ fast_tokenizer_class = f"{last_module}.{obj.__class__.__name__}"
+ if getattr(obj, "slow_tokenizer_class", None) is not None:
+ slow_tokenizer = getattr(obj, "slow_tokenizer_class")
+ slow_tok_module_name = slow_tokenizer.__module__
+ last_slow_tok_module = slow_tok_module_name.split(".")[-1]
+ slow_tokenizer_class = f"{last_slow_tok_module}.{slow_tokenizer.__name__}"
+ else:
+ # Slow tokenizer: no way to have the fast class
+ slow_tokenizer_class = f"{last_module}.{obj.__class__.__name__}"
+
+ full_name = (slow_tokenizer_class, fast_tokenizer_class)
+
+ if isinstance(_config, dict):
+ auto_map = _config.get("auto_map", {})
+ auto_map[obj._auto_class] = full_name
+ _config["auto_map"] = auto_map
+ elif getattr(_config, "auto_map", None) is not None:
+ _config.auto_map[obj._auto_class] = full_name
+ else:
+ _config.auto_map = {obj._auto_class: full_name}
+
+ # Add object class to the config auto_map
+ if isinstance(config, (list, tuple)):
+ for cfg in config:
+ _set_auto_map_in_config(cfg)
+ elif config is not None:
+ _set_auto_map_in_config(config)
+
+ result = []
+ # Copy module file to the output folder.
+ object_file = sys.modules[obj.__module__].__file__
+ dest_file = Path(folder) / (Path(object_file).name)
+ shutil.copy(object_file, dest_file)
+ result.append(dest_file)
+
+ # Gather all relative imports recursively and make sure they are copied as well.
+ for needed_file in get_relative_import_files(object_file):
+ dest_file = Path(folder) / (Path(needed_file).name)
+ shutil.copy(needed_file, dest_file)
+ result.append(dest_file)
+
+ return result
+
+
+def _raise_timeout_error(signum, frame):
+ raise ValueError(
+ "Loading this model requires you to execute custom code contained in the model repository on your local "
+ "machine. Please set the option `trust_remote_code=True` to permit loading of this model."
+ )
+
+
+TIME_OUT_REMOTE_CODE = 15
+
+
+def resolve_trust_remote_code(
+ trust_remote_code, model_name, has_local_code, has_remote_code, error_message=None, upstream_repo=None
+):
+ """
+ Resolves the `trust_remote_code` argument. If there is remote code to be loaded, the user must opt-in to loading
+ it.
+
+ Args:
+ trust_remote_code (`bool` or `None`):
+ User-defined `trust_remote_code` value.
+ model_name (`str`):
+ The name of the model repository in huggingface.co.
+ has_local_code (`bool`):
+ Whether the model has local code.
+ has_remote_code (`bool`):
+ Whether the model has remote code.
+ error_message (`str`, *optional*):
+ Custom error message to display if there is remote code to load and the user didn't opt-in. If unset, the error
+ message will be regarding loading a model with custom code.
+
+ Returns:
+ The resolved `trust_remote_code` value.
+ """
+ if error_message is None:
+ if upstream_repo is not None:
+ error_message = (
+ f"The repository {model_name} references custom code contained in {upstream_repo} which "
+ f"must be executed to correctly load the model. You can inspect the repository "
+ f"content at https://hf.co/{upstream_repo} .\n"
+ )
+ elif os.path.isdir(model_name):
+ error_message = (
+ f"The repository {model_name} contains custom code which must be executed "
+ f"to correctly load the model. You can inspect the repository "
+ f"content at {os.path.abspath(model_name)} .\n"
+ )
+ else:
+ error_message = (
+ f"The repository {model_name} contains custom code which must be executed "
+ f"to correctly load the model. You can inspect the repository "
+ f"content at https://hf.co/{model_name} .\n"
+ )
+
+ if trust_remote_code is None:
+ if has_local_code:
+ trust_remote_code = False
+ elif has_remote_code and TIME_OUT_REMOTE_CODE > 0:
+ prev_sig_handler = None
+ try:
+ prev_sig_handler = signal.signal(signal.SIGALRM, _raise_timeout_error)
+ signal.alarm(TIME_OUT_REMOTE_CODE)
+ while trust_remote_code is None:
+ answer = input(
+ f"{error_message} You can inspect the repository content at https://hf.co/{model_name}.\n"
+ f"You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n\n"
+ f"Do you wish to run the custom code? [y/N] "
+ )
+ if answer.lower() in ["yes", "y", "1"]:
+ trust_remote_code = True
+ elif answer.lower() in ["no", "n", "0", ""]:
+ trust_remote_code = False
+ signal.alarm(0)
+ except Exception:
+ # OS which does not support signal.SIGALRM
+ raise ValueError(
+ f"{error_message} You can inspect the repository content at https://hf.co/{model_name}.\n"
+ f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
+ )
+ finally:
+ if prev_sig_handler is not None:
+ signal.signal(signal.SIGALRM, prev_sig_handler)
+ signal.alarm(0)
+ elif has_remote_code:
+ # For the CI which puts the timeout at 0
+ _raise_timeout_error(None, None)
+
+ if has_remote_code and not has_local_code and not trust_remote_code:
+ raise ValueError(
+ f"{error_message} You can inspect the repository content at https://hf.co/{model_name}.\n"
+ f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
+ )
+
+ return trust_remote_code
+
+
+def check_python_requirements(path_or_repo_id, requirements_file="requirements.txt", **kwargs):
+ """
+ Tries to locate `requirements_file` in a local folder or repo, and confirms that the environment has all the
+ python dependencies installed.
+
+ Args:
+ path_or_repo_id (`str` or `os.PathLike`):
+ This can be either:
+ - a string, the *model id* of a model repo on huggingface.co.
+ - a path to a *directory* potentially containing the file.
+ kwargs (`dict[str, Any]`, *optional*):
+ Additional arguments to pass to `cached_file`.
+ """
+ failed = [] # error messages regarding requirements
+ try:
+ requirements = cached_file(path_or_repo_id=path_or_repo_id, filename=requirements_file, **kwargs)
+ with open(requirements, "r") as f:
+ requirements = f.readlines()
+
+ for requirement in requirements:
+ requirement = requirement.strip()
+ if not requirement or requirement.startswith("#"): # skip empty lines and comments
+ continue
+
+ try:
+ # e.g. "torch>2.6.0" -> "torch", ">", "2.6.0"
+ package_name, delimiter, version_number = split_package_version(requirement)
+ except ValueError: # e.g. "torch", as opposed to "torch>2.6.0"
+ package_name = requirement
+ delimiter, version_number = None, None
+
+ try:
+ local_package_version = importlib.metadata.version(package_name)
+ except importlib.metadata.PackageNotFoundError:
+ failed.append(f"{requirement} (installed: None)")
+ continue
+
+ if delimiter is not None and version_number is not None:
+ is_satisfied = VersionComparison.from_string(delimiter)(
+ version.parse(local_package_version), version.parse(version_number)
+ )
+ else:
+ is_satisfied = True
+
+ if not is_satisfied:
+ failed.append(f"{requirement} (installed: {local_package_version})")
+
+ except OSError: # no requirements.txt
+ pass
+
+ if failed:
+ raise ImportError(
+ f"Missing requirements in your local environment for `{path_or_repo_id}`:\n" + "\n".join(failed)
+ )
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/feature_extraction_sequence_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/feature_extraction_sequence_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0be17bd7d28cf61c70c177af4d025d0321617d7
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/feature_extraction_sequence_utils.py
@@ -0,0 +1,371 @@
+# Copyright 2021 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Sequence feature extraction class for common feature extractors to preprocess sequences.
+"""
+
+from typing import Optional, Union
+
+import numpy as np
+
+from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin
+from .utils import PaddingStrategy, TensorType, is_tf_tensor, is_torch_tensor, logging, to_numpy
+
+
+logger = logging.get_logger(__name__)
+
+
+class SequenceFeatureExtractor(FeatureExtractionMixin):
+ """
+ This is a general feature extraction class for speech recognition.
+
+ Args:
+ feature_size (`int`):
+ The feature dimension of the extracted features.
+ sampling_rate (`int`):
+ The sampling rate at which the audio files should be digitalized expressed in hertz (Hz).
+ padding_value (`float`):
+ The value that is used to fill the padding values / vectors.
+ """
+
+ def __init__(self, feature_size: int, sampling_rate: int, padding_value: float, **kwargs):
+ self.feature_size = feature_size
+ self.sampling_rate = sampling_rate
+ self.padding_value = padding_value
+
+ self.padding_side = kwargs.pop("padding_side", "right")
+ self.return_attention_mask = kwargs.pop("return_attention_mask", True)
+
+ super().__init__(**kwargs)
+
+ def pad(
+ self,
+ processed_features: Union[
+ BatchFeature,
+ list[BatchFeature],
+ dict[str, BatchFeature],
+ dict[str, list[BatchFeature]],
+ list[dict[str, BatchFeature]],
+ ],
+ padding: Union[bool, str, PaddingStrategy] = True,
+ max_length: Optional[int] = None,
+ truncation: bool = False,
+ pad_to_multiple_of: Optional[int] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ ) -> BatchFeature:
+ """
+ Pad input values / input vectors or a batch of input values / input vectors up to predefined length or to the
+ max sequence length in the batch.
+
+ Padding side (left/right) padding values are defined at the feature extractor level (with `self.padding_side`,
+ `self.padding_value`)
+
+
+
+ If the `processed_features` passed are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the
+ result will use the same type unless you provide a different tensor type with `return_tensors`. In the case of
+ PyTorch tensors, you will lose the specific device of your tensors however.
+
+
+
+ Args:
+ processed_features ([`BatchFeature`], list of [`BatchFeature`], `dict[str, list[float]]`, `dict[str, list[list[float]]` or `list[dict[str, list[float]]]`):
+ Processed inputs. Can represent one input ([`BatchFeature`] or `dict[str, list[float]]`) or a batch of
+ input values / vectors (list of [`BatchFeature`], *dict[str, list[list[float]]]* or *list[dict[str,
+ list[float]]]*) so you can use this method during preprocessing as well as in a PyTorch Dataloader
+ collate function.
+
+ Instead of `list[float]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors),
+ see the note above for the return type.
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding
+ index) among:
+
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+ sequence if provided).
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+ acceptable input length for the model if that argument is not provided.
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
+ lengths).
+ max_length (`int`, *optional*):
+ Maximum length of the returned list and optionally padding length (see above).
+ truncation (`bool`):
+ Activates truncation to cut input sequences longer than `max_length` to `max_length`.
+ pad_to_multiple_of (`int`, *optional*):
+ If set will pad the sequence to a multiple of the provided value.
+
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
+ `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128.
+ return_attention_mask (`bool`, *optional*):
+ Whether to return the attention mask. If left to the default, will return the attention mask according
+ to the specific feature_extractor's default.
+
+ [What are attention masks?](../glossary#attention-mask)
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors instead of list of python integers. Acceptable values are:
+
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return Numpy `np.ndarray` objects.
+ """
+ # If we have a list of dicts, let's convert it in a dict of lists
+ # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
+ if isinstance(processed_features, (list, tuple)) and isinstance(processed_features[0], (dict, BatchFeature)):
+ processed_features = {
+ key: [example[key] for example in processed_features] for key in processed_features[0]
+ }
+
+ # The model's main input name, usually `input_values`, has be passed for padding
+ if self.model_input_names[0] not in processed_features:
+ raise ValueError(
+ "You should supply an instance of `transformers.BatchFeature` or list of `transformers.BatchFeature`"
+ f" to this method that includes {self.model_input_names[0]}, but you provided"
+ f" {list(processed_features.keys())}"
+ )
+
+ required_input = processed_features[self.model_input_names[0]]
+ return_attention_mask = (
+ return_attention_mask if return_attention_mask is not None else self.return_attention_mask
+ )
+
+ if len(required_input) == 0:
+ if return_attention_mask:
+ processed_features["attention_mask"] = []
+ return processed_features
+
+ # If we have PyTorch/TF tensors or lists as inputs, we cast them as Numpy arrays
+ # and rebuild them afterwards if no return_tensors is specified
+ # Note that we lose the specific device the tensor may be on for PyTorch
+
+ first_element = required_input[0]
+ if isinstance(first_element, (list, tuple)):
+ # first_element might be an empty list/tuple in some edge cases so we grab the first non empty element.
+ index = 0
+ while len(required_input[index]) == 0:
+ index += 1
+ if index < len(required_input):
+ first_element = required_input[index][0]
+
+ if return_tensors is None:
+ if is_tf_tensor(first_element):
+ return_tensors = "tf"
+ elif is_torch_tensor(first_element):
+ return_tensors = "pt"
+ elif isinstance(first_element, (int, float, list, tuple, np.ndarray)):
+ return_tensors = "np"
+ else:
+ raise ValueError(
+ f"type of {first_element} unknown: {type(first_element)}. "
+ "Should be one of a python, numpy, pytorch or tensorflow object."
+ )
+
+ for key, value in processed_features.items():
+ if isinstance(value[0], (int, float)):
+ processed_features[key] = to_numpy(value)
+ else:
+ processed_features[key] = [to_numpy(v) for v in value]
+
+ # Convert padding_strategy in PaddingStrategy
+ padding_strategy = self._get_padding_strategies(padding=padding, max_length=max_length)
+
+ required_input = processed_features[self.model_input_names[0]]
+
+ batch_size = len(required_input)
+ if not all(len(v) == batch_size for v in processed_features.values()):
+ raise ValueError("Some items in the output dictionary have a different batch size than others.")
+
+ truncated_inputs = []
+ for i in range(batch_size):
+ inputs = {k: v[i] for k, v in processed_features.items()}
+ # truncation
+ inputs_slice = self._truncate(
+ inputs,
+ max_length=max_length,
+ pad_to_multiple_of=pad_to_multiple_of,
+ truncation=truncation,
+ )
+ truncated_inputs.append(inputs_slice)
+
+ if padding_strategy == PaddingStrategy.LONGEST:
+ # make sure that `max_length` cannot be longer than the longest truncated length
+ max_length = max(len(input_slice[self.model_input_names[0]]) for input_slice in truncated_inputs)
+ padding_strategy = PaddingStrategy.MAX_LENGTH
+
+ batch_outputs = {}
+ for i in range(batch_size):
+ # padding
+ outputs = self._pad(
+ truncated_inputs[i],
+ max_length=max_length,
+ padding_strategy=padding_strategy,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_attention_mask=return_attention_mask,
+ )
+
+ for key, value in outputs.items():
+ if key not in batch_outputs:
+ batch_outputs[key] = []
+ if value.dtype is np.dtype(np.float64):
+ value = value.astype(np.float32)
+ batch_outputs[key].append(value)
+
+ return BatchFeature(batch_outputs, tensor_type=return_tensors)
+
+ def _pad(
+ self,
+ processed_features: Union[dict[str, np.ndarray], BatchFeature],
+ max_length: Optional[int] = None,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ pad_to_multiple_of: Optional[int] = None,
+ return_attention_mask: Optional[bool] = None,
+ ) -> dict:
+ """
+ Pad inputs (on left/right and up to predefined length or max length in the batch)
+
+ Args:
+ processed_features (`Union[dict[str, np.ndarray], BatchFeature]`):
+ Dictionary of input values (`np.ndarray[float]`) / input vectors (`list[np.ndarray[float]]`) or batch
+ of inputs values (`list[np.ndarray[int]]`) / input vectors (`list[np.ndarray[int]]`)
+ max_length (`int`, *optional*):
+ Maximum length of the returned list and optionally padding length (see below)
+ padding_strategy (`PaddingStrategy`, *optional*, default to `PaddingStrategy.DO_NOT_PAD`):
+ PaddingStrategy to use for padding.
+
+ - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
+ - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
+ - PaddingStrategy.DO_NOT_PAD: Do not pad
+ The feature_extractor padding sides are defined in self.padding_side:
+
+ - 'left': pads on the left of the sequences
+ - 'right': pads on the right of the sequences
+ pad_to_multiple_of (`int`, *optional*):
+ Integer if set will pad the sequence to a multiple of the provided value. This is especially useful to
+ enable the use of Tensor Core on NVIDIA hardware with compute capability `>= 7.5` (Volta), or on TPUs
+ which benefit from having sequence lengths be a multiple of 128.
+ return_attention_mask (`bool`, *optional*):
+ Set to False to avoid returning attention mask (default: set to model specifics)
+ """
+ required_input = processed_features[self.model_input_names[0]]
+
+ if padding_strategy == PaddingStrategy.LONGEST:
+ max_length = len(required_input)
+
+ if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
+
+ needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) < max_length
+
+ if return_attention_mask and "attention_mask" not in processed_features:
+ processed_features["attention_mask"] = np.ones(len(required_input), dtype=np.int32)
+
+ if needs_to_be_padded:
+ difference = max_length - len(required_input)
+ if self.padding_side == "right":
+ if return_attention_mask:
+ processed_features["attention_mask"] = np.pad(
+ processed_features["attention_mask"], (0, difference)
+ )
+ padding_shape = ((0, difference), (0, 0)) if self.feature_size > 1 else (0, difference)
+ processed_features[self.model_input_names[0]] = np.pad(
+ required_input, padding_shape, "constant", constant_values=self.padding_value
+ )
+ elif self.padding_side == "left":
+ if return_attention_mask:
+ processed_features["attention_mask"] = np.pad(
+ processed_features["attention_mask"], (difference, 0)
+ )
+ padding_shape = ((difference, 0), (0, 0)) if self.feature_size > 1 else (difference, 0)
+ processed_features[self.model_input_names[0]] = np.pad(
+ required_input, padding_shape, "constant", constant_values=self.padding_value
+ )
+ else:
+ raise ValueError("Invalid padding strategy:" + str(self.padding_side))
+
+ return processed_features
+
+ def _truncate(
+ self,
+ processed_features: Union[dict[str, np.ndarray], BatchFeature],
+ max_length: Optional[int] = None,
+ pad_to_multiple_of: Optional[int] = None,
+ truncation: Optional[bool] = None,
+ ):
+ """
+ Truncate inputs to predefined length or max length in the batch
+
+ Args:
+ processed_features(`Union[dict[str, np.ndarray], BatchFeature]`):
+ Dictionary of input values (`np.ndarray[float]`) / input vectors (`list[np.ndarray[float]]`) or batch
+ of inputs values (`list[np.ndarray[int]]`) / input vectors (`list[np.ndarray[int]]`)
+ max_length (`int`, *optional*):
+ maximum length of the returned list and optionally padding length (see below)
+ pad_to_multiple_of (`int`, *optional*) :
+ Integer if set will pad the sequence to a multiple of the provided value. This is especially useful to
+ enable the use of Tensor Core on NVIDIA hardware with compute capability `>= 7.5` (Volta), or on TPUs
+ which benefit from having sequence lengths be a multiple of 128.
+ truncation (`bool`, *optional*):
+ Activates truncation to cut input sequences longer than `max_length` to `max_length`.
+ """
+ if not truncation:
+ return processed_features
+ elif truncation and max_length is None:
+ raise ValueError("When setting ``truncation=True``, make sure that ``max_length`` is defined.")
+
+ required_input = processed_features[self.model_input_names[0]]
+
+ # find `max_length` that fits `pad_to_multiple_of`
+ if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
+
+ needs_to_be_truncated = len(required_input) > max_length
+
+ if needs_to_be_truncated:
+ processed_features[self.model_input_names[0]] = processed_features[self.model_input_names[0]][:max_length]
+ if "attention_mask" in processed_features:
+ processed_features["attention_mask"] = processed_features["attention_mask"][:max_length]
+
+ return processed_features
+
+ def _get_padding_strategies(self, padding=False, max_length=None):
+ """
+ Find the correct padding strategy
+ """
+
+ # Get padding strategy
+ if padding is not False:
+ if padding is True:
+ padding_strategy = PaddingStrategy.LONGEST # Default to pad to the longest sequence in the batch
+ elif not isinstance(padding, PaddingStrategy):
+ padding_strategy = PaddingStrategy(padding)
+ elif isinstance(padding, PaddingStrategy):
+ padding_strategy = padding
+ else:
+ padding_strategy = PaddingStrategy.DO_NOT_PAD
+
+ # Set max length if needed
+ if max_length is None:
+ if padding_strategy == PaddingStrategy.MAX_LENGTH:
+ raise ValueError(
+ f"When setting ``padding={PaddingStrategy.MAX_LENGTH}``, make sure that max_length is defined"
+ )
+
+ # Test if we have a padding value
+ if padding_strategy != PaddingStrategy.DO_NOT_PAD and (self.padding_value is None):
+ raise ValueError(
+ "Asking to pad but the feature_extractor does not have a padding value. Please select a value to use"
+ " as `padding_value`. For example: `feature_extractor.padding_value = 0.0`."
+ )
+
+ return padding_strategy
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/feature_extraction_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/feature_extraction_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e007e72d47612e377b3ce76c556d14e73a699095
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/feature_extraction_utils.py
@@ -0,0 +1,697 @@
+# Copyright 2021 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Feature extraction saving/loading class for common feature extractors.
+"""
+
+import copy
+import json
+import os
+import warnings
+from collections import UserDict
+from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
+
+import numpy as np
+
+from .dynamic_module_utils import custom_object_save
+from .utils import (
+ FEATURE_EXTRACTOR_NAME,
+ PROCESSOR_NAME,
+ PushToHubMixin,
+ TensorType,
+ copy_func,
+ download_url,
+ is_flax_available,
+ is_jax_tensor,
+ is_numpy_array,
+ is_offline_mode,
+ is_remote_url,
+ is_tf_available,
+ is_torch_available,
+ is_torch_device,
+ is_torch_dtype,
+ logging,
+ requires_backends,
+)
+from .utils.hub import cached_file
+
+
+if TYPE_CHECKING:
+ from .feature_extraction_sequence_utils import SequenceFeatureExtractor
+
+
+logger = logging.get_logger(__name__)
+
+PreTrainedFeatureExtractor = Union["SequenceFeatureExtractor"]
+
+# type hinting: specifying the type of feature extractor class that inherits from FeatureExtractionMixin
+SpecificFeatureExtractorType = TypeVar("SpecificFeatureExtractorType", bound="FeatureExtractionMixin")
+
+
+class BatchFeature(UserDict):
+ r"""
+ Holds the output of the [`~SequenceFeatureExtractor.pad`] and feature extractor specific `__call__` methods.
+
+ This class is derived from a python dictionary and can be used as a dictionary.
+
+ Args:
+ data (`dict`, *optional*):
+ Dictionary of lists/arrays/tensors returned by the __call__/pad methods ('input_values', 'attention_mask',
+ etc.).
+ tensor_type (`Union[None, str, TensorType]`, *optional*):
+ You can give a tensor_type here to convert the lists of integers in PyTorch/TensorFlow/Numpy Tensors at
+ initialization.
+ """
+
+ def __init__(self, data: Optional[dict[str, Any]] = None, tensor_type: Union[None, str, TensorType] = None):
+ super().__init__(data)
+ self.convert_to_tensors(tensor_type=tensor_type)
+
+ def __getitem__(self, item: str) -> Any:
+ """
+ If the key is a string, returns the value of the dict associated to `key` ('input_values', 'attention_mask',
+ etc.).
+ """
+ if isinstance(item, str):
+ return self.data[item]
+ else:
+ raise KeyError("Indexing with integers is not available when using Python based feature extractors")
+
+ def __getattr__(self, item: str):
+ try:
+ return self.data[item]
+ except KeyError:
+ raise AttributeError
+
+ def __getstate__(self):
+ return {"data": self.data}
+
+ def __setstate__(self, state):
+ if "data" in state:
+ self.data = state["data"]
+
+ def _get_is_as_tensor_fns(self, tensor_type: Optional[Union[str, TensorType]] = None):
+ if tensor_type is None:
+ return None, None
+
+ # Convert to TensorType
+ if not isinstance(tensor_type, TensorType):
+ tensor_type = TensorType(tensor_type)
+
+ # Get a function reference for the correct framework
+ if tensor_type == TensorType.TENSORFLOW:
+ logger.warning_once(
+ "TensorFlow and JAX classes are deprecated and will be removed in Transformers v5. We "
+ "recommend migrating to PyTorch classes or pinning your version of Transformers."
+ )
+ if not is_tf_available():
+ raise ImportError(
+ "Unable to convert output to TensorFlow tensors format, TensorFlow is not installed."
+ )
+ import tensorflow as tf
+
+ as_tensor = tf.constant
+ is_tensor = tf.is_tensor
+ elif tensor_type == TensorType.PYTORCH:
+ if not is_torch_available():
+ raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
+ import torch
+
+ def as_tensor(value):
+ if isinstance(value, (list, tuple)) and len(value) > 0:
+ if isinstance(value[0], np.ndarray):
+ value = np.array(value)
+ elif (
+ isinstance(value[0], (list, tuple))
+ and len(value[0]) > 0
+ and isinstance(value[0][0], np.ndarray)
+ ):
+ value = np.array(value)
+ if isinstance(value, np.ndarray):
+ return torch.from_numpy(value)
+ else:
+ return torch.tensor(value)
+
+ is_tensor = torch.is_tensor
+ elif tensor_type == TensorType.JAX:
+ logger.warning_once(
+ "TensorFlow and JAX classes are deprecated and will be removed in Transformers v5. We "
+ "recommend migrating to PyTorch classes or pinning your version of Transformers."
+ )
+ if not is_flax_available():
+ raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.")
+ import jax.numpy as jnp # noqa: F811
+
+ as_tensor = jnp.array
+ is_tensor = is_jax_tensor
+ else:
+
+ def as_tensor(value, dtype=None):
+ if isinstance(value, (list, tuple)) and isinstance(value[0], (list, tuple, np.ndarray)):
+ value_lens = [len(val) for val in value]
+ if len(set(value_lens)) > 1 and dtype is None:
+ # we have a ragged list so handle explicitly
+ value = as_tensor([np.asarray(val) for val in value], dtype=object)
+ return np.asarray(value, dtype=dtype)
+
+ is_tensor = is_numpy_array
+ return is_tensor, as_tensor
+
+ def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None):
+ """
+ Convert the inner content to tensors.
+
+ Args:
+ tensor_type (`str` or [`~utils.TensorType`], *optional*):
+ The type of tensors to use. If `str`, should be one of the values of the enum [`~utils.TensorType`]. If
+ `None`, no modification is done.
+ """
+ if tensor_type is None:
+ return self
+
+ is_tensor, as_tensor = self._get_is_as_tensor_fns(tensor_type)
+
+ # Do the tensor conversion in batch
+ for key, value in self.items():
+ try:
+ if not is_tensor(value):
+ tensor = as_tensor(value)
+
+ self[key] = tensor
+ except: # noqa E722
+ if key == "overflowing_values":
+ raise ValueError("Unable to create tensor returning overflowing values of different lengths. ")
+ raise ValueError(
+ "Unable to create tensor, you should probably activate padding "
+ "with 'padding=True' to have batched tensors with the same length."
+ )
+
+ return self
+
+ def to(self, *args, **kwargs) -> "BatchFeature":
+ """
+ Send all values to device by calling `v.to(*args, **kwargs)` (PyTorch only). This should support casting in
+ different `dtypes` and sending the `BatchFeature` to a different `device`.
+
+ Args:
+ args (`Tuple`):
+ Will be passed to the `to(...)` function of the tensors.
+ kwargs (`Dict`, *optional*):
+ Will be passed to the `to(...)` function of the tensors.
+ To enable asynchronous data transfer, set the `non_blocking` flag in `kwargs` (defaults to `False`).
+
+ Returns:
+ [`BatchFeature`]: The same instance after modification.
+ """
+ requires_backends(self, ["torch"])
+ import torch
+
+ device = kwargs.get("device")
+ non_blocking = kwargs.get("non_blocking", False)
+ # Check if the args are a device or a dtype
+ if device is None and len(args) > 0:
+ # device should be always the first argument
+ arg = args[0]
+ if is_torch_dtype(arg):
+ # The first argument is a dtype
+ pass
+ elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int):
+ device = arg
+ else:
+ # it's something else
+ raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")
+
+ # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
+ def maybe_to(v):
+ # check if v is a floating point
+ if isinstance(v, torch.Tensor) and torch.is_floating_point(v):
+ # cast and send to device
+ return v.to(*args, **kwargs)
+ elif isinstance(v, torch.Tensor) and device is not None:
+ return v.to(device=device, non_blocking=non_blocking)
+ else:
+ return v
+
+ self.data = {k: maybe_to(v) for k, v in self.items()}
+ return self
+
+
+class FeatureExtractionMixin(PushToHubMixin):
+ """
+ This is a feature extraction mixin used to provide saving/loading functionality for sequential and image feature
+ extractors.
+ """
+
+ _auto_class = None
+
+ def __init__(self, **kwargs):
+ """Set elements of `kwargs` as attributes."""
+ # Pop "processor_class" as it should be saved as private attribute
+ self._processor_class = kwargs.pop("processor_class", None)
+ # Additional attributes without default values
+ for key, value in kwargs.items():
+ try:
+ setattr(self, key, value)
+ except AttributeError as err:
+ logger.error(f"Can't set {key} with value {value} for {self}")
+ raise err
+
+ def _set_processor_class(self, processor_class: str):
+ """Sets processor class as an attribute."""
+ self._processor_class = processor_class
+
+ @classmethod
+ def from_pretrained(
+ cls: type[SpecificFeatureExtractorType],
+ pretrained_model_name_or_path: Union[str, os.PathLike],
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
+ force_download: bool = False,
+ local_files_only: bool = False,
+ token: Optional[Union[str, bool]] = None,
+ revision: str = "main",
+ **kwargs,
+ ) -> SpecificFeatureExtractorType:
+ r"""
+ Instantiate a type of [`~feature_extraction_utils.FeatureExtractionMixin`] from a feature extractor, *e.g.* a
+ derived class of [`SequenceFeatureExtractor`].
+
+ Args:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ This can be either:
+
+ - a string, the *model id* of a pretrained feature_extractor hosted inside a model repo on
+ huggingface.co.
+ - a path to a *directory* containing a feature extractor file saved using the
+ [`~feature_extraction_utils.FeatureExtractionMixin.save_pretrained`] method, e.g.,
+ `./my_model_directory/`.
+ - a path or url to a saved feature extractor JSON *file*, e.g.,
+ `./my_model_directory/preprocessor_config.json`.
+ cache_dir (`str` or `os.PathLike`, *optional*):
+ Path to a directory in which a downloaded pretrained model feature extractor should be cached if the
+ standard cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force to (re-)download the feature extractor files and override the cached versions
+ if they exist.
+ resume_download:
+ Deprecated and ignored. All downloads are now resumed by default when possible.
+ Will be removed in v5 of Transformers.
+ proxies (`dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
+ token (`str` or `bool`, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
+ the token generated when running `hf auth login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+
+
+
+
+ To test a pull request you made on the Hub, you can pass `revision="refs/pr/"`.
+
+
+
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
+ If `False`, then this function returns just the final feature extractor object. If `True`, then this
+ functions returns a `Tuple(feature_extractor, unused_kwargs)` where *unused_kwargs* is a dictionary
+ consisting of the key/value pairs whose keys are not feature extractor attributes: i.e., the part of
+ `kwargs` which has not been used to update `feature_extractor` and is otherwise ignored.
+ kwargs (`dict[str, Any]`, *optional*):
+ The values in kwargs of any keys which are feature extractor attributes will be used to override the
+ loaded values. Behavior concerning key/value pairs whose keys are *not* feature extractor attributes is
+ controlled by the `return_unused_kwargs` keyword parameter.
+
+ Returns:
+ A feature extractor of type [`~feature_extraction_utils.FeatureExtractionMixin`].
+
+ Examples:
+
+ ```python
+ # We can't instantiate directly the base class *FeatureExtractionMixin* nor *SequenceFeatureExtractor* so let's show the examples on a
+ # derived class: *Wav2Vec2FeatureExtractor*
+ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
+ "facebook/wav2vec2-base-960h"
+ ) # Download feature_extraction_config from huggingface.co and cache.
+ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
+ "./test/saved_model/"
+ ) # E.g. feature_extractor (or model) was saved using *save_pretrained('./test/saved_model/')*
+ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("./test/saved_model/preprocessor_config.json")
+ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
+ "facebook/wav2vec2-base-960h", return_attention_mask=False, foo=False
+ )
+ assert feature_extractor.return_attention_mask is False
+ feature_extractor, unused_kwargs = Wav2Vec2FeatureExtractor.from_pretrained(
+ "facebook/wav2vec2-base-960h", return_attention_mask=False, foo=False, return_unused_kwargs=True
+ )
+ assert feature_extractor.return_attention_mask is False
+ assert unused_kwargs == {"foo": False}
+ ```"""
+ kwargs["cache_dir"] = cache_dir
+ kwargs["force_download"] = force_download
+ kwargs["local_files_only"] = local_files_only
+ kwargs["revision"] = revision
+
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ if use_auth_token is not None:
+ warnings.warn(
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+ FutureWarning,
+ )
+ if token is not None:
+ raise ValueError(
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+ )
+ token = use_auth_token
+
+ if token is not None:
+ kwargs["token"] = token
+
+ feature_extractor_dict, kwargs = cls.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
+
+ return cls.from_dict(feature_extractor_dict, **kwargs)
+
+ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
+ """
+ Save a feature_extractor object to the directory `save_directory`, so that it can be re-loaded using the
+ [`~feature_extraction_utils.FeatureExtractionMixin.from_pretrained`] class method.
+
+ Args:
+ save_directory (`str` or `os.PathLike`):
+ Directory where the feature extractor JSON file will be saved (will be created if it does not exist).
+ push_to_hub (`bool`, *optional*, defaults to `False`):
+ Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
+ repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
+ namespace).
+ kwargs (`dict[str, Any]`, *optional*):
+ Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
+ """
+ use_auth_token = kwargs.pop("use_auth_token", None)
+
+ if use_auth_token is not None:
+ warnings.warn(
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+ FutureWarning,
+ )
+ if kwargs.get("token") is not None:
+ raise ValueError(
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+ )
+ kwargs["token"] = use_auth_token
+
+ if os.path.isfile(save_directory):
+ raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
+
+ os.makedirs(save_directory, exist_ok=True)
+
+ if push_to_hub:
+ commit_message = kwargs.pop("commit_message", None)
+ repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
+ repo_id = self._create_repo(repo_id, **kwargs)
+ files_timestamps = self._get_files_timestamps(save_directory)
+
+ # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
+ # loaded from the Hub.
+ if self._auto_class is not None:
+ custom_object_save(self, save_directory, config=self)
+
+ # If we save using the predefined names, we can load using `from_pretrained`
+ output_feature_extractor_file = os.path.join(save_directory, FEATURE_EXTRACTOR_NAME)
+
+ self.to_json_file(output_feature_extractor_file)
+ logger.info(f"Feature extractor saved in {output_feature_extractor_file}")
+
+ if push_to_hub:
+ self._upload_modified_files(
+ save_directory,
+ repo_id,
+ files_timestamps,
+ commit_message=commit_message,
+ token=kwargs.get("token"),
+ )
+
+ return [output_feature_extractor_file]
+
+ @classmethod
+ def get_feature_extractor_dict(
+ cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
+ ) -> tuple[dict[str, Any], dict[str, Any]]:
+ """
+ From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a
+ feature extractor of type [`~feature_extraction_utils.FeatureExtractionMixin`] using `from_dict`.
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
+
+ Returns:
+ `tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the feature extractor object.
+ """
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ resume_download = kwargs.pop("resume_download", None)
+ proxies = kwargs.pop("proxies", None)
+ subfolder = kwargs.pop("subfolder", None)
+ token = kwargs.pop("token", None)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ local_files_only = kwargs.pop("local_files_only", False)
+ revision = kwargs.pop("revision", None)
+
+ if use_auth_token is not None:
+ warnings.warn(
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+ FutureWarning,
+ )
+ if token is not None:
+ raise ValueError(
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+ )
+ token = use_auth_token
+
+ from_pipeline = kwargs.pop("_from_pipeline", None)
+ from_auto_class = kwargs.pop("_from_auto", False)
+
+ user_agent = {"file_type": "feature extractor", "from_auto_class": from_auto_class}
+ if from_pipeline is not None:
+ user_agent["using_pipeline"] = from_pipeline
+
+ if is_offline_mode() and not local_files_only:
+ logger.info("Offline mode: forcing local_files_only=True")
+ local_files_only = True
+
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+ is_local = os.path.isdir(pretrained_model_name_or_path)
+ if os.path.isdir(pretrained_model_name_or_path):
+ feature_extractor_file = os.path.join(pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME)
+ if os.path.isfile(pretrained_model_name_or_path):
+ resolved_feature_extractor_file = pretrained_model_name_or_path
+ is_local = True
+ elif is_remote_url(pretrained_model_name_or_path):
+ feature_extractor_file = pretrained_model_name_or_path
+ resolved_feature_extractor_file = download_url(pretrained_model_name_or_path)
+ else:
+ feature_extractor_file = FEATURE_EXTRACTOR_NAME
+ try:
+ # Load from local folder or from cache or download from model Hub and cache
+ resolved_feature_extractor_files = [
+ resolved_file
+ for filename in [feature_extractor_file, PROCESSOR_NAME]
+ if (
+ resolved_file := cached_file(
+ pretrained_model_name_or_path,
+ filename=filename,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ subfolder=subfolder,
+ token=token,
+ user_agent=user_agent,
+ revision=revision,
+ _raise_exceptions_for_missing_entries=False,
+ )
+ )
+ is not None
+ ]
+ resolved_feature_extractor_file = resolved_feature_extractor_files[0]
+ except OSError:
+ # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
+ # the original exception.
+ raise
+ except Exception:
+ # For any other exception, we throw a generic error.
+ raise OSError(
+ f"Can't load feature extractor for '{pretrained_model_name_or_path}'. If you were trying to load"
+ " it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
+ f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
+ f" directory containing a {FEATURE_EXTRACTOR_NAME} file"
+ )
+
+ try:
+ # Load feature_extractor dict
+ with open(resolved_feature_extractor_file, encoding="utf-8") as reader:
+ text = reader.read()
+ feature_extractor_dict = json.loads(text)
+ feature_extractor_dict = feature_extractor_dict.get("feature_extractor", feature_extractor_dict)
+
+ except json.JSONDecodeError:
+ raise OSError(
+ f"It looks like the config file at '{resolved_feature_extractor_file}' is not a valid JSON file."
+ )
+
+ if is_local:
+ logger.info(f"loading configuration file {resolved_feature_extractor_file}")
+ else:
+ logger.info(
+ f"loading configuration file {feature_extractor_file} from cache at {resolved_feature_extractor_file}"
+ )
+
+ return feature_extractor_dict, kwargs
+
+ @classmethod
+ def from_dict(
+ cls, feature_extractor_dict: dict[str, Any], **kwargs
+ ) -> Union["FeatureExtractionMixin", tuple["FeatureExtractionMixin", dict[str, Any]]]:
+ """
+ Instantiates a type of [`~feature_extraction_utils.FeatureExtractionMixin`] from a Python dictionary of
+ parameters.
+
+ Args:
+ feature_extractor_dict (`dict[str, Any]`):
+ Dictionary that will be used to instantiate the feature extractor object. Such a dictionary can be
+ retrieved from a pretrained checkpoint by leveraging the
+ [`~feature_extraction_utils.FeatureExtractionMixin.to_dict`] method.
+ kwargs (`dict[str, Any]`):
+ Additional parameters from which to initialize the feature extractor object.
+
+ Returns:
+ [`~feature_extraction_utils.FeatureExtractionMixin`]: The feature extractor object instantiated from those
+ parameters.
+ """
+ return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
+
+ # Update feature_extractor with kwargs if needed
+ to_remove = []
+ for key, value in kwargs.items():
+ if key in feature_extractor_dict:
+ feature_extractor_dict[key] = value
+ to_remove.append(key)
+ for key in to_remove:
+ kwargs.pop(key, None)
+
+ feature_extractor = cls(**feature_extractor_dict)
+
+ logger.info(f"Feature extractor {feature_extractor}")
+ if return_unused_kwargs:
+ return feature_extractor, kwargs
+ else:
+ return feature_extractor
+
+ def to_dict(self) -> dict[str, Any]:
+ """
+ Serializes this instance to a Python dictionary. Returns:
+ `dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
+ """
+ output = copy.deepcopy(self.__dict__)
+ output["feature_extractor_type"] = self.__class__.__name__
+ if "mel_filters" in output:
+ del output["mel_filters"]
+ if "window" in output:
+ del output["window"]
+ return output
+
+ @classmethod
+ def from_json_file(cls, json_file: Union[str, os.PathLike]) -> "FeatureExtractionMixin":
+ """
+ Instantiates a feature extractor of type [`~feature_extraction_utils.FeatureExtractionMixin`] from the path to
+ a JSON file of parameters.
+
+ Args:
+ json_file (`str` or `os.PathLike`):
+ Path to the JSON file containing the parameters.
+
+ Returns:
+ A feature extractor of type [`~feature_extraction_utils.FeatureExtractionMixin`]: The feature_extractor
+ object instantiated from that JSON file.
+ """
+ with open(json_file, encoding="utf-8") as reader:
+ text = reader.read()
+ feature_extractor_dict = json.loads(text)
+ return cls(**feature_extractor_dict)
+
+ def to_json_string(self) -> str:
+ """
+ Serializes this instance to a JSON string.
+
+ Returns:
+ `str`: String containing all the attributes that make up this feature_extractor instance in JSON format.
+ """
+ dictionary = self.to_dict()
+
+ for key, value in dictionary.items():
+ if isinstance(value, np.ndarray):
+ dictionary[key] = value.tolist()
+
+ # make sure private name "_processor_class" is correctly
+ # saved as "processor_class"
+ _processor_class = dictionary.pop("_processor_class", None)
+ if _processor_class is not None:
+ dictionary["processor_class"] = _processor_class
+
+ return json.dumps(dictionary, indent=2, sort_keys=True) + "\n"
+
+ def to_json_file(self, json_file_path: Union[str, os.PathLike]):
+ """
+ Save this instance to a JSON file.
+
+ Args:
+ json_file_path (`str` or `os.PathLike`):
+ Path to the JSON file in which this feature_extractor instance's parameters will be saved.
+ """
+ with open(json_file_path, "w", encoding="utf-8") as writer:
+ writer.write(self.to_json_string())
+
+ def __repr__(self):
+ return f"{self.__class__.__name__} {self.to_json_string()}"
+
+ @classmethod
+ def register_for_auto_class(cls, auto_class="AutoFeatureExtractor"):
+ """
+ Register this class with a given auto class. This should only be used for custom feature extractors as the ones
+ in the library are already mapped with `AutoFeatureExtractor`.
+
+
+
+ Args:
+ auto_class (`str` or `type`, *optional*, defaults to `"AutoFeatureExtractor"`):
+ The auto class to register this new feature extractor with.
+ """
+ if not isinstance(auto_class, str):
+ auto_class = auto_class.__name__
+
+ import transformers.models.auto as auto_module
+
+ if not hasattr(auto_module, auto_class):
+ raise ValueError(f"{auto_class} is not a valid auto class.")
+
+ cls._auto_class = auto_class
+
+
+FeatureExtractionMixin.push_to_hub = copy_func(FeatureExtractionMixin.push_to_hub)
+if FeatureExtractionMixin.push_to_hub.__doc__ is not None:
+ FeatureExtractionMixin.push_to_hub.__doc__ = FeatureExtractionMixin.push_to_hub.__doc__.format(
+ object="feature extractor", object_class="AutoFeatureExtractor", object_files="feature extractor file"
+ )
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/file_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/file_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc6f722262d99e0123423e66d2a5db3edd118ff5
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/file_utils.py
@@ -0,0 +1,130 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+File utilities: utilities related to download and cache models
+
+This module should not be update anymore and is only left for backward compatibility.
+"""
+
+from huggingface_hub import get_full_repo_name # for backward compatibility
+from huggingface_hub.constants import HF_HUB_DISABLE_TELEMETRY as DISABLE_TELEMETRY # for backward compatibility
+
+from . import __version__
+
+# Backward compatibility imports, to make sure all those objects can be found in file_utils
+from .utils import (
+ CLOUDFRONT_DISTRIB_PREFIX,
+ CONFIG_NAME,
+ DUMMY_INPUTS,
+ DUMMY_MASK,
+ ENV_VARS_TRUE_AND_AUTO_VALUES,
+ ENV_VARS_TRUE_VALUES,
+ FEATURE_EXTRACTOR_NAME,
+ FLAX_WEIGHTS_NAME,
+ HF_MODULES_CACHE,
+ HUGGINGFACE_CO_PREFIX,
+ HUGGINGFACE_CO_RESOLVE_ENDPOINT,
+ MODEL_CARD_NAME,
+ MULTIPLE_CHOICE_DUMMY_INPUTS,
+ PYTORCH_PRETRAINED_BERT_CACHE,
+ PYTORCH_TRANSFORMERS_CACHE,
+ S3_BUCKET_PREFIX,
+ SENTENCEPIECE_UNDERLINE,
+ SPIECE_UNDERLINE,
+ TF2_WEIGHTS_NAME,
+ TF_WEIGHTS_NAME,
+ TORCH_FX_REQUIRED_VERSION,
+ TRANSFORMERS_CACHE,
+ TRANSFORMERS_DYNAMIC_MODULE_NAME,
+ USE_JAX,
+ USE_TF,
+ USE_TORCH,
+ WEIGHTS_INDEX_NAME,
+ WEIGHTS_NAME,
+ ContextManagers,
+ DummyObject,
+ EntryNotFoundError,
+ ExplicitEnum,
+ ModelOutput,
+ PaddingStrategy,
+ PushToHubMixin,
+ RepositoryNotFoundError,
+ RevisionNotFoundError,
+ TensorType,
+ _LazyModule,
+ add_code_sample_docstrings,
+ add_end_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ copy_func,
+ default_cache_path,
+ define_sagemaker_information,
+ get_torch_version,
+ has_file,
+ http_user_agent,
+ is_apex_available,
+ is_bs4_available,
+ is_coloredlogs_available,
+ is_datasets_available,
+ is_detectron2_available,
+ is_faiss_available,
+ is_flax_available,
+ is_ftfy_available,
+ is_g2p_en_available,
+ is_in_notebook,
+ is_ipex_available,
+ is_librosa_available,
+ is_offline_mode,
+ is_onnx_available,
+ is_pandas_available,
+ is_phonemizer_available,
+ is_protobuf_available,
+ is_psutil_available,
+ is_py3nvml_available,
+ is_pyctcdecode_available,
+ is_pytesseract_available,
+ is_pytorch_quantization_available,
+ is_rjieba_available,
+ is_sagemaker_dp_enabled,
+ is_sagemaker_mp_enabled,
+ is_scipy_available,
+ is_sentencepiece_available,
+ is_seqio_available,
+ is_sklearn_available,
+ is_soundfile_available,
+ is_spacy_available,
+ is_speech_available,
+ is_tensor,
+ is_tensorflow_probability_available,
+ is_tf2onnx_available,
+ is_tf_available,
+ is_timm_available,
+ is_tokenizers_available,
+ is_torch_available,
+ is_torch_bf16_available,
+ is_torch_cuda_available,
+ is_torch_fx_available,
+ is_torch_fx_proxy,
+ is_torch_mps_available,
+ is_torch_tf32_available,
+ is_torch_xla_available,
+ is_torchaudio_available,
+ is_training_run_on_sagemaker,
+ is_vision_available,
+ replace_return_docstrings,
+ requires_backends,
+ to_numpy,
+ to_py_obj,
+ torch_only_method,
+)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/hyperparameter_search.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/hyperparameter_search.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8558ceed32f4640c727a192d3ab00ed9104986f
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/hyperparameter_search.py
@@ -0,0 +1,141 @@
+# Copyright 2023-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional
+
+from .integrations import (
+ is_optuna_available,
+ is_ray_tune_available,
+ is_sigopt_available,
+ is_wandb_available,
+ run_hp_search_optuna,
+ run_hp_search_ray,
+ run_hp_search_sigopt,
+ run_hp_search_wandb,
+)
+from .trainer_utils import (
+ HPSearchBackend,
+ default_hp_space_optuna,
+ default_hp_space_ray,
+ default_hp_space_sigopt,
+ default_hp_space_wandb,
+)
+from .utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class HyperParamSearchBackendBase:
+ name: str
+ pip_package: Optional[str] = None
+
+ @staticmethod
+ def is_available():
+ raise NotImplementedError
+
+ def run(self, trainer, n_trials: int, direction: str, **kwargs):
+ raise NotImplementedError
+
+ def default_hp_space(self, trial):
+ raise NotImplementedError
+
+ def ensure_available(self):
+ if not self.is_available():
+ raise RuntimeError(
+ f"You picked the {self.name} backend, but it is not installed. Run {self.pip_install()}."
+ )
+
+ @classmethod
+ def pip_install(cls):
+ return f"`pip install {cls.pip_package or cls.name}`"
+
+
+class OptunaBackend(HyperParamSearchBackendBase):
+ name = "optuna"
+
+ @staticmethod
+ def is_available():
+ return is_optuna_available()
+
+ def run(self, trainer, n_trials: int, direction: str, **kwargs):
+ return run_hp_search_optuna(trainer, n_trials, direction, **kwargs)
+
+ def default_hp_space(self, trial):
+ return default_hp_space_optuna(trial)
+
+
+class RayTuneBackend(HyperParamSearchBackendBase):
+ name = "ray"
+ pip_package = "'ray[tune]'"
+
+ @staticmethod
+ def is_available():
+ return is_ray_tune_available()
+
+ def run(self, trainer, n_trials: int, direction: str, **kwargs):
+ return run_hp_search_ray(trainer, n_trials, direction, **kwargs)
+
+ def default_hp_space(self, trial):
+ return default_hp_space_ray(trial)
+
+
+class SigOptBackend(HyperParamSearchBackendBase):
+ name = "sigopt"
+
+ @staticmethod
+ def is_available():
+ return is_sigopt_available()
+
+ def run(self, trainer, n_trials: int, direction: str, **kwargs):
+ return run_hp_search_sigopt(trainer, n_trials, direction, **kwargs)
+
+ def default_hp_space(self, trial):
+ return default_hp_space_sigopt(trial)
+
+
+class WandbBackend(HyperParamSearchBackendBase):
+ name = "wandb"
+
+ @staticmethod
+ def is_available():
+ return is_wandb_available()
+
+ def run(self, trainer, n_trials: int, direction: str, **kwargs):
+ return run_hp_search_wandb(trainer, n_trials, direction, **kwargs)
+
+ def default_hp_space(self, trial):
+ return default_hp_space_wandb(trial)
+
+
+ALL_HYPERPARAMETER_SEARCH_BACKENDS = {
+ HPSearchBackend(backend.name): backend for backend in [OptunaBackend, RayTuneBackend, SigOptBackend, WandbBackend]
+}
+
+
+def default_hp_search_backend() -> str:
+ available_backends = [backend for backend in ALL_HYPERPARAMETER_SEARCH_BACKENDS.values() if backend.is_available()]
+ if len(available_backends) > 0:
+ name = available_backends[0].name
+ if len(available_backends) > 1:
+ logger.info(
+ f"{len(available_backends)} hyperparameter search backends available. Using {name} as the default."
+ )
+ return name
+ raise RuntimeError(
+ "No hyperparameter search backend available.\n"
+ + "\n".join(
+ f" - To install {backend.name} run {backend.pip_install()}"
+ for backend in ALL_HYPERPARAMETER_SEARCH_BACKENDS.values()
+ )
+ )
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/image_processing_base.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/image_processing_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfe94ffd0df794588b91544ffe9764ee5717c327
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/image_processing_base.py
@@ -0,0 +1,543 @@
+# Copyright 2020 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import copy
+import json
+import os
+import warnings
+from typing import Any, Optional, TypeVar, Union
+
+import numpy as np
+
+from .dynamic_module_utils import custom_object_save
+from .feature_extraction_utils import BatchFeature as BaseBatchFeature
+from .image_utils import is_valid_image, load_image
+from .utils import (
+ IMAGE_PROCESSOR_NAME,
+ PROCESSOR_NAME,
+ PushToHubMixin,
+ copy_func,
+ download_url,
+ is_offline_mode,
+ is_remote_url,
+ logging,
+)
+from .utils.hub import cached_file
+
+
+ImageProcessorType = TypeVar("ImageProcessorType", bound="ImageProcessingMixin")
+
+
+logger = logging.get_logger(__name__)
+
+
+# TODO: Move BatchFeature to be imported by both image_processing_utils and image_processing_utils_fast
+# We override the class string here, but logic is the same.
+class BatchFeature(BaseBatchFeature):
+ r"""
+ Holds the output of the image processor specific `__call__` methods.
+
+ This class is derived from a python dictionary and can be used as a dictionary.
+
+ Args:
+ data (`dict`):
+ Dictionary of lists/arrays/tensors returned by the __call__ method ('pixel_values', etc.).
+ tensor_type (`Union[None, str, TensorType]`, *optional*):
+ You can give a tensor_type here to convert the lists of integers in PyTorch/TensorFlow/Numpy Tensors at
+ initialization.
+ """
+
+
+# TODO: (Amy) - factor out the common parts of this and the feature extractor
+class ImageProcessingMixin(PushToHubMixin):
+ """
+ This is an image processor mixin used to provide saving/loading functionality for sequential and image feature
+ extractors.
+ """
+
+ _auto_class = None
+
+ def __init__(self, **kwargs):
+ """Set elements of `kwargs` as attributes."""
+ # This key was saved while we still used `XXXFeatureExtractor` for image processing. Now we use
+ # `XXXImageProcessor`, this attribute and its value are misleading.
+ kwargs.pop("feature_extractor_type", None)
+ # Pop "processor_class" as it should be saved as private attribute
+ self._processor_class = kwargs.pop("processor_class", None)
+ # Additional attributes without default values
+ for key, value in kwargs.items():
+ try:
+ setattr(self, key, value)
+ except AttributeError as err:
+ logger.error(f"Can't set {key} with value {value} for {self}")
+ raise err
+
+ def _set_processor_class(self, processor_class: str):
+ """Sets processor class as an attribute."""
+ self._processor_class = processor_class
+
+ @classmethod
+ def from_pretrained(
+ cls: type[ImageProcessorType],
+ pretrained_model_name_or_path: Union[str, os.PathLike],
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
+ force_download: bool = False,
+ local_files_only: bool = False,
+ token: Optional[Union[str, bool]] = None,
+ revision: str = "main",
+ **kwargs,
+ ) -> ImageProcessorType:
+ r"""
+ Instantiate a type of [`~image_processing_utils.ImageProcessingMixin`] from an image processor.
+
+ Args:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ This can be either:
+
+ - a string, the *model id* of a pretrained image_processor hosted inside a model repo on
+ huggingface.co.
+ - a path to a *directory* containing a image processor file saved using the
+ [`~image_processing_utils.ImageProcessingMixin.save_pretrained`] method, e.g.,
+ `./my_model_directory/`.
+ - a path or url to a saved image processor JSON *file*, e.g.,
+ `./my_model_directory/preprocessor_config.json`.
+ cache_dir (`str` or `os.PathLike`, *optional*):
+ Path to a directory in which a downloaded pretrained model image processor should be cached if the
+ standard cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force to (re-)download the image processor files and override the cached versions if
+ they exist.
+ resume_download:
+ Deprecated and ignored. All downloads are now resumed by default when possible.
+ Will be removed in v5 of Transformers.
+ proxies (`dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
+ token (`str` or `bool`, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
+ the token generated when running `hf auth login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+
+
+
+
+ To test a pull request you made on the Hub, you can pass `revision="refs/pr/"`.
+
+
+
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
+ If `False`, then this function returns just the final image processor object. If `True`, then this
+ functions returns a `Tuple(image_processor, unused_kwargs)` where *unused_kwargs* is a dictionary
+ consisting of the key/value pairs whose keys are not image processor attributes: i.e., the part of
+ `kwargs` which has not been used to update `image_processor` and is otherwise ignored.
+ subfolder (`str`, *optional*, defaults to `""`):
+ In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
+ specify the folder name here.
+ kwargs (`dict[str, Any]`, *optional*):
+ The values in kwargs of any keys which are image processor attributes will be used to override the
+ loaded values. Behavior concerning key/value pairs whose keys are *not* image processor attributes is
+ controlled by the `return_unused_kwargs` keyword parameter.
+
+ Returns:
+ A image processor of type [`~image_processing_utils.ImageProcessingMixin`].
+
+ Examples:
+
+ ```python
+ # We can't instantiate directly the base class *ImageProcessingMixin* so let's show the examples on a
+ # derived class: *CLIPImageProcessor*
+ image_processor = CLIPImageProcessor.from_pretrained(
+ "openai/clip-vit-base-patch32"
+ ) # Download image_processing_config from huggingface.co and cache.
+ image_processor = CLIPImageProcessor.from_pretrained(
+ "./test/saved_model/"
+ ) # E.g. image processor (or model) was saved using *save_pretrained('./test/saved_model/')*
+ image_processor = CLIPImageProcessor.from_pretrained("./test/saved_model/preprocessor_config.json")
+ image_processor = CLIPImageProcessor.from_pretrained(
+ "openai/clip-vit-base-patch32", do_normalize=False, foo=False
+ )
+ assert image_processor.do_normalize is False
+ image_processor, unused_kwargs = CLIPImageProcessor.from_pretrained(
+ "openai/clip-vit-base-patch32", do_normalize=False, foo=False, return_unused_kwargs=True
+ )
+ assert image_processor.do_normalize is False
+ assert unused_kwargs == {"foo": False}
+ ```"""
+ kwargs["cache_dir"] = cache_dir
+ kwargs["force_download"] = force_download
+ kwargs["local_files_only"] = local_files_only
+ kwargs["revision"] = revision
+
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ if use_auth_token is not None:
+ warnings.warn(
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+ FutureWarning,
+ )
+ if token is not None:
+ raise ValueError(
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+ )
+ token = use_auth_token
+
+ if token is not None:
+ kwargs["token"] = token
+
+ image_processor_dict, kwargs = cls.get_image_processor_dict(pretrained_model_name_or_path, **kwargs)
+
+ return cls.from_dict(image_processor_dict, **kwargs)
+
+ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
+ """
+ Save an image processor object to the directory `save_directory`, so that it can be re-loaded using the
+ [`~image_processing_utils.ImageProcessingMixin.from_pretrained`] class method.
+
+ Args:
+ save_directory (`str` or `os.PathLike`):
+ Directory where the image processor JSON file will be saved (will be created if it does not exist).
+ push_to_hub (`bool`, *optional*, defaults to `False`):
+ Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
+ repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
+ namespace).
+ kwargs (`dict[str, Any]`, *optional*):
+ Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
+ """
+ use_auth_token = kwargs.pop("use_auth_token", None)
+
+ if use_auth_token is not None:
+ warnings.warn(
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+ FutureWarning,
+ )
+ if kwargs.get("token") is not None:
+ raise ValueError(
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+ )
+ kwargs["token"] = use_auth_token
+
+ if os.path.isfile(save_directory):
+ raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
+
+ os.makedirs(save_directory, exist_ok=True)
+
+ if push_to_hub:
+ commit_message = kwargs.pop("commit_message", None)
+ repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
+ repo_id = self._create_repo(repo_id, **kwargs)
+ files_timestamps = self._get_files_timestamps(save_directory)
+
+ # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
+ # loaded from the Hub.
+ if self._auto_class is not None:
+ custom_object_save(self, save_directory, config=self)
+
+ # If we save using the predefined names, we can load using `from_pretrained`
+ output_image_processor_file = os.path.join(save_directory, IMAGE_PROCESSOR_NAME)
+
+ self.to_json_file(output_image_processor_file)
+ logger.info(f"Image processor saved in {output_image_processor_file}")
+
+ if push_to_hub:
+ self._upload_modified_files(
+ save_directory,
+ repo_id,
+ files_timestamps,
+ commit_message=commit_message,
+ token=kwargs.get("token"),
+ )
+
+ return [output_image_processor_file]
+
+ @classmethod
+ def get_image_processor_dict(
+ cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
+ ) -> tuple[dict[str, Any], dict[str, Any]]:
+ """
+ From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a
+ image processor of type [`~image_processor_utils.ImageProcessingMixin`] using `from_dict`.
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
+ subfolder (`str`, *optional*, defaults to `""`):
+ In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
+ specify the folder name here.
+ image_processor_filename (`str`, *optional*, defaults to `"config.json"`):
+ The name of the file in the model directory to use for the image processor config.
+
+ Returns:
+ `tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the image processor object.
+ """
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ resume_download = kwargs.pop("resume_download", None)
+ proxies = kwargs.pop("proxies", None)
+ token = kwargs.pop("token", None)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ local_files_only = kwargs.pop("local_files_only", False)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", "")
+ image_processor_filename = kwargs.pop("image_processor_filename", IMAGE_PROCESSOR_NAME)
+
+ from_pipeline = kwargs.pop("_from_pipeline", None)
+ from_auto_class = kwargs.pop("_from_auto", False)
+
+ if use_auth_token is not None:
+ warnings.warn(
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+ FutureWarning,
+ )
+ if token is not None:
+ raise ValueError(
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+ )
+ token = use_auth_token
+
+ user_agent = {"file_type": "image processor", "from_auto_class": from_auto_class}
+ if from_pipeline is not None:
+ user_agent["using_pipeline"] = from_pipeline
+
+ if is_offline_mode() and not local_files_only:
+ logger.info("Offline mode: forcing local_files_only=True")
+ local_files_only = True
+
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+ is_local = os.path.isdir(pretrained_model_name_or_path)
+ if os.path.isdir(pretrained_model_name_or_path):
+ image_processor_file = os.path.join(pretrained_model_name_or_path, image_processor_filename)
+ if os.path.isfile(pretrained_model_name_or_path):
+ resolved_image_processor_file = pretrained_model_name_or_path
+ is_local = True
+ elif is_remote_url(pretrained_model_name_or_path):
+ image_processor_file = pretrained_model_name_or_path
+ resolved_image_processor_file = download_url(pretrained_model_name_or_path)
+ else:
+ image_processor_file = image_processor_filename
+ try:
+ # Load from local folder or from cache or download from model Hub and cache
+ resolved_image_processor_files = [
+ resolved_file
+ for filename in [image_processor_file, PROCESSOR_NAME]
+ if (
+ resolved_file := cached_file(
+ pretrained_model_name_or_path,
+ filename=filename,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ token=token,
+ user_agent=user_agent,
+ revision=revision,
+ subfolder=subfolder,
+ _raise_exceptions_for_missing_entries=False,
+ )
+ )
+ is not None
+ ]
+ resolved_image_processor_file = resolved_image_processor_files[0]
+ except OSError:
+ # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
+ # the original exception.
+ raise
+ except Exception:
+ # For any other exception, we throw a generic error.
+ raise OSError(
+ f"Can't load image processor for '{pretrained_model_name_or_path}'. If you were trying to load"
+ " it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
+ f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
+ f" directory containing a {image_processor_filename} file"
+ )
+
+ try:
+ # Load image_processor dict
+ with open(resolved_image_processor_file, encoding="utf-8") as reader:
+ text = reader.read()
+ image_processor_dict = json.loads(text)
+ image_processor_dict = image_processor_dict.get("image_processor", image_processor_dict)
+
+ except json.JSONDecodeError:
+ raise OSError(
+ f"It looks like the config file at '{resolved_image_processor_file}' is not a valid JSON file."
+ )
+
+ if is_local:
+ logger.info(f"loading configuration file {resolved_image_processor_file}")
+ else:
+ logger.info(
+ f"loading configuration file {image_processor_file} from cache at {resolved_image_processor_file}"
+ )
+
+ return image_processor_dict, kwargs
+
+ @classmethod
+ def from_dict(cls, image_processor_dict: dict[str, Any], **kwargs):
+ """
+ Instantiates a type of [`~image_processing_utils.ImageProcessingMixin`] from a Python dictionary of parameters.
+
+ Args:
+ image_processor_dict (`dict[str, Any]`):
+ Dictionary that will be used to instantiate the image processor object. Such a dictionary can be
+ retrieved from a pretrained checkpoint by leveraging the
+ [`~image_processing_utils.ImageProcessingMixin.to_dict`] method.
+ kwargs (`dict[str, Any]`):
+ Additional parameters from which to initialize the image processor object.
+
+ Returns:
+ [`~image_processing_utils.ImageProcessingMixin`]: The image processor object instantiated from those
+ parameters.
+ """
+ image_processor_dict = image_processor_dict.copy()
+ return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
+
+ # The `size` parameter is a dict and was previously an int or tuple in feature extractors.
+ # We set `size` here directly to the `image_processor_dict` so that it is converted to the appropriate
+ # dict within the image processor and isn't overwritten if `size` is passed in as a kwarg.
+ if "size" in kwargs and "size" in image_processor_dict:
+ image_processor_dict["size"] = kwargs.pop("size")
+ if "crop_size" in kwargs and "crop_size" in image_processor_dict:
+ image_processor_dict["crop_size"] = kwargs.pop("crop_size")
+
+ image_processor = cls(**image_processor_dict)
+
+ # Update image_processor with kwargs if needed
+ to_remove = []
+ for key, value in kwargs.items():
+ if hasattr(image_processor, key):
+ setattr(image_processor, key, value)
+ to_remove.append(key)
+ for key in to_remove:
+ kwargs.pop(key, None)
+
+ logger.info(f"Image processor {image_processor}")
+ if return_unused_kwargs:
+ return image_processor, kwargs
+ else:
+ return image_processor
+
+ def to_dict(self) -> dict[str, Any]:
+ """
+ Serializes this instance to a Python dictionary.
+
+ Returns:
+ `dict[str, Any]`: Dictionary of all the attributes that make up this image processor instance.
+ """
+ output = copy.deepcopy(self.__dict__)
+ output["image_processor_type"] = self.__class__.__name__
+
+ return output
+
+ @classmethod
+ def from_json_file(cls, json_file: Union[str, os.PathLike]):
+ """
+ Instantiates a image processor of type [`~image_processing_utils.ImageProcessingMixin`] from the path to a JSON
+ file of parameters.
+
+ Args:
+ json_file (`str` or `os.PathLike`):
+ Path to the JSON file containing the parameters.
+
+ Returns:
+ A image processor of type [`~image_processing_utils.ImageProcessingMixin`]: The image_processor object
+ instantiated from that JSON file.
+ """
+ with open(json_file, encoding="utf-8") as reader:
+ text = reader.read()
+ image_processor_dict = json.loads(text)
+ return cls(**image_processor_dict)
+
+ def to_json_string(self) -> str:
+ """
+ Serializes this instance to a JSON string.
+
+ Returns:
+ `str`: String containing all the attributes that make up this feature_extractor instance in JSON format.
+ """
+ dictionary = self.to_dict()
+
+ for key, value in dictionary.items():
+ if isinstance(value, np.ndarray):
+ dictionary[key] = value.tolist()
+
+ # make sure private name "_processor_class" is correctly
+ # saved as "processor_class"
+ _processor_class = dictionary.pop("_processor_class", None)
+ if _processor_class is not None:
+ dictionary["processor_class"] = _processor_class
+
+ return json.dumps(dictionary, indent=2, sort_keys=True) + "\n"
+
+ def to_json_file(self, json_file_path: Union[str, os.PathLike]):
+ """
+ Save this instance to a JSON file.
+
+ Args:
+ json_file_path (`str` or `os.PathLike`):
+ Path to the JSON file in which this image_processor instance's parameters will be saved.
+ """
+ with open(json_file_path, "w", encoding="utf-8") as writer:
+ writer.write(self.to_json_string())
+
+ def __repr__(self):
+ return f"{self.__class__.__name__} {self.to_json_string()}"
+
+ @classmethod
+ def register_for_auto_class(cls, auto_class="AutoImageProcessor"):
+ """
+ Register this class with a given auto class. This should only be used for custom image processors as the ones
+ in the library are already mapped with `AutoImageProcessor `.
+
+
+
+ Args:
+ auto_class (`str` or `type`, *optional*, defaults to `"AutoImageProcessor "`):
+ The auto class to register this new image processor with.
+ """
+ if not isinstance(auto_class, str):
+ auto_class = auto_class.__name__
+
+ import transformers.models.auto as auto_module
+
+ if not hasattr(auto_module, auto_class):
+ raise ValueError(f"{auto_class} is not a valid auto class.")
+
+ cls._auto_class = auto_class
+
+ def fetch_images(self, image_url_or_urls: Union[str, list[str], list[list[str]]]):
+ """
+ Convert a single or a list of urls into the corresponding `PIL.Image` objects.
+
+ If a single url is passed, the return value will be a single object. If a list is passed a list of objects is
+ returned.
+ """
+ if isinstance(image_url_or_urls, list):
+ return [self.fetch_images(x) for x in image_url_or_urls]
+ elif isinstance(image_url_or_urls, str):
+ return load_image(image_url_or_urls)
+ elif is_valid_image(image_url_or_urls):
+ return image_url_or_urls
+ else:
+ raise TypeError(f"only a single or a list of entries is supported but got type={type(image_url_or_urls)}")
+
+
+ImageProcessingMixin.push_to_hub = copy_func(ImageProcessingMixin.push_to_hub)
+if ImageProcessingMixin.push_to_hub.__doc__ is not None:
+ ImageProcessingMixin.push_to_hub.__doc__ = ImageProcessingMixin.push_to_hub.__doc__.format(
+ object="image processor", object_class="AutoImageProcessor", object_files="image processor file"
+ )
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/image_processing_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/image_processing_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..52b798c09f84cf014b7568dd0634d739d32e0508
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/image_processing_utils.py
@@ -0,0 +1,317 @@
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from collections.abc import Iterable
+from typing import Optional, Union
+
+import numpy as np
+
+from .image_processing_base import BatchFeature, ImageProcessingMixin
+from .image_transforms import center_crop, normalize, rescale
+from .image_utils import ChannelDimension, get_image_size
+from .utils import logging
+from .utils.import_utils import requires
+
+
+logger = logging.get_logger(__name__)
+
+
+INIT_SERVICE_KWARGS = [
+ "processor_class",
+ "image_processor_type",
+]
+
+
+@requires(backends=("vision",))
+class BaseImageProcessor(ImageProcessingMixin):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ @property
+ def is_fast(self) -> bool:
+ """
+ `bool`: Whether or not this image processor is a fast processor (backed by PyTorch and TorchVision).
+ """
+ return False
+
+ def __call__(self, images, **kwargs) -> BatchFeature:
+ """Preprocess an image or a batch of images."""
+ return self.preprocess(images, **kwargs)
+
+ def preprocess(self, images, **kwargs) -> BatchFeature:
+ raise NotImplementedError("Each image processor must implement its own preprocess method")
+
+ def rescale(
+ self,
+ image: np.ndarray,
+ scale: float,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+ Rescale an image by a scale factor. image = image * scale.
+
+ Args:
+ image (`np.ndarray`):
+ Image to rescale.
+ scale (`float`):
+ The scaling factor to rescale pixel values by.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
+ image is used. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+
+ Returns:
+ `np.ndarray`: The rescaled image.
+ """
+ return rescale(image, scale=scale, data_format=data_format, input_data_format=input_data_format, **kwargs)
+
+ def normalize(
+ self,
+ image: np.ndarray,
+ mean: Union[float, Iterable[float]],
+ std: Union[float, Iterable[float]],
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+ Normalize an image. image = (image - image_mean) / image_std.
+
+ Args:
+ image (`np.ndarray`):
+ Image to normalize.
+ mean (`float` or `Iterable[float]`):
+ Image mean to use for normalization.
+ std (`float` or `Iterable[float]`):
+ Image standard deviation to use for normalization.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
+ image is used. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+
+ Returns:
+ `np.ndarray`: The normalized image.
+ """
+ return normalize(
+ image, mean=mean, std=std, data_format=data_format, input_data_format=input_data_format, **kwargs
+ )
+
+ def center_crop(
+ self,
+ image: np.ndarray,
+ size: dict[str, int],
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+ Center crop an image to `(size["height"], size["width"])`. If the input size is smaller than `crop_size` along
+ any edge, the image is padded with 0's and then center cropped.
+
+ Args:
+ image (`np.ndarray`):
+ Image to center crop.
+ size (`dict[str, int]`):
+ Size of the output image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
+ image is used. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ """
+ size = get_size_dict(size)
+ if "height" not in size or "width" not in size:
+ raise ValueError(f"The size dictionary must have keys 'height' and 'width'. Got {size.keys()}")
+ return center_crop(
+ image,
+ size=(size["height"], size["width"]),
+ data_format=data_format,
+ input_data_format=input_data_format,
+ **kwargs,
+ )
+
+ def to_dict(self):
+ encoder_dict = super().to_dict()
+ encoder_dict.pop("_valid_processor_keys", None)
+ return encoder_dict
+
+
+VALID_SIZE_DICT_KEYS = (
+ {"height", "width"},
+ {"shortest_edge"},
+ {"shortest_edge", "longest_edge"},
+ {"longest_edge"},
+ {"max_height", "max_width"},
+)
+
+
+def is_valid_size_dict(size_dict):
+ if not isinstance(size_dict, dict):
+ return False
+
+ size_dict_keys = set(size_dict.keys())
+ for allowed_keys in VALID_SIZE_DICT_KEYS:
+ if size_dict_keys == allowed_keys:
+ return True
+ return False
+
+
+def convert_to_size_dict(
+ size, max_size: Optional[int] = None, default_to_square: bool = True, height_width_order: bool = True
+):
+ # By default, if size is an int we assume it represents a tuple of (size, size).
+ if isinstance(size, int) and default_to_square:
+ if max_size is not None:
+ raise ValueError("Cannot specify both size as an int, with default_to_square=True and max_size")
+ return {"height": size, "width": size}
+ # In other configs, if size is an int and default_to_square is False, size represents the length of
+ # the shortest edge after resizing.
+ elif isinstance(size, int) and not default_to_square:
+ size_dict = {"shortest_edge": size}
+ if max_size is not None:
+ size_dict["longest_edge"] = max_size
+ return size_dict
+ # Otherwise, if size is a tuple it's either (height, width) or (width, height)
+ elif isinstance(size, (tuple, list)) and height_width_order:
+ return {"height": size[0], "width": size[1]}
+ elif isinstance(size, (tuple, list)) and not height_width_order:
+ return {"height": size[1], "width": size[0]}
+ elif size is None and max_size is not None:
+ if default_to_square:
+ raise ValueError("Cannot specify both default_to_square=True and max_size")
+ return {"longest_edge": max_size}
+
+ raise ValueError(f"Could not convert size input to size dict: {size}")
+
+
+def get_size_dict(
+ size: Optional[Union[int, Iterable[int], dict[str, int]]] = None,
+ max_size: Optional[int] = None,
+ height_width_order: bool = True,
+ default_to_square: bool = True,
+ param_name="size",
+) -> dict:
+ """
+ Converts the old size parameter in the config into the new dict expected in the config. This is to ensure backwards
+ compatibility with the old image processor configs and removes ambiguity over whether the tuple is in (height,
+ width) or (width, height) format.
+
+ - If `size` is tuple, it is converted to `{"height": size[0], "width": size[1]}` or `{"height": size[1], "width":
+ size[0]}` if `height_width_order` is `False`.
+ - If `size` is an int, and `default_to_square` is `True`, it is converted to `{"height": size, "width": size}`.
+ - If `size` is an int and `default_to_square` is False, it is converted to `{"shortest_edge": size}`. If `max_size`
+ is set, it is added to the dict as `{"longest_edge": max_size}`.
+
+ Args:
+ size (`Union[int, Iterable[int], dict[str, int]]`, *optional*):
+ The `size` parameter to be cast into a size dictionary.
+ max_size (`Optional[int]`, *optional*):
+ The `max_size` parameter to be cast into a size dictionary.
+ height_width_order (`bool`, *optional*, defaults to `True`):
+ If `size` is a tuple, whether it's in (height, width) or (width, height) order.
+ default_to_square (`bool`, *optional*, defaults to `True`):
+ If `size` is an int, whether to default to a square image or not.
+ """
+ if not isinstance(size, dict):
+ size_dict = convert_to_size_dict(size, max_size, default_to_square, height_width_order)
+ logger.info(
+ f"{param_name} should be a dictionary on of the following set of keys: {VALID_SIZE_DICT_KEYS}, got {size}."
+ f" Converted to {size_dict}.",
+ )
+ else:
+ size_dict = size
+
+ if not is_valid_size_dict(size_dict):
+ raise ValueError(
+ f"{param_name} must have one of the following set of keys: {VALID_SIZE_DICT_KEYS}, got {size_dict.keys()}"
+ )
+ return size_dict
+
+
+def select_best_resolution(original_size: tuple, possible_resolutions: list) -> tuple:
+ """
+ Selects the best resolution from a list of possible resolutions based on the original size.
+
+ This is done by calculating the effective and wasted resolution for each possible resolution.
+
+ The best fit resolution is the one that maximizes the effective resolution and minimizes the wasted resolution.
+
+ Args:
+ original_size (tuple):
+ The original size of the image in the format (height, width).
+ possible_resolutions (list):
+ A list of possible resolutions in the format [(height1, width1), (height2, width2), ...].
+
+ Returns:
+ tuple: The best fit resolution in the format (height, width).
+ """
+ original_height, original_width = original_size
+ best_fit = None
+ max_effective_resolution = 0
+ min_wasted_resolution = float("inf")
+
+ for height, width in possible_resolutions:
+ scale = min(width / original_width, height / original_height)
+ downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
+ effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
+ wasted_resolution = (width * height) - effective_resolution
+
+ if effective_resolution > max_effective_resolution or (
+ effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution
+ ):
+ max_effective_resolution = effective_resolution
+ min_wasted_resolution = wasted_resolution
+ best_fit = (height, width)
+
+ return best_fit
+
+
+def get_patch_output_size(image, target_resolution, input_data_format):
+ """
+ Given an image and a target resolution, calculate the output size of the image after cropping to the target
+ """
+ original_height, original_width = get_image_size(image, channel_dim=input_data_format)
+ target_height, target_width = target_resolution
+
+ scale_w = target_width / original_width
+ scale_h = target_height / original_height
+
+ if scale_w < scale_h:
+ new_width = target_width
+ new_height = min(math.ceil(original_height * scale_w), target_height)
+ else:
+ new_height = target_height
+ new_width = min(math.ceil(original_width * scale_h), target_width)
+
+ return new_height, new_width
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/image_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/image_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5f4d4a3fa4cd2540d76c919a6548f52cbed76b4
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/image_utils.py
@@ -0,0 +1,969 @@
+# Copyright 2021 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import base64
+import os
+from collections.abc import Iterable
+from dataclasses import dataclass
+from io import BytesIO
+from typing import Optional, Union
+
+import numpy as np
+import requests
+
+from .utils import (
+ ExplicitEnum,
+ is_jax_tensor,
+ is_numpy_array,
+ is_tf_tensor,
+ is_torch_available,
+ is_torch_tensor,
+ is_torchvision_available,
+ is_vision_available,
+ logging,
+ requires_backends,
+ to_numpy,
+)
+from .utils.constants import ( # noqa: F401
+ IMAGENET_DEFAULT_MEAN,
+ IMAGENET_DEFAULT_STD,
+ IMAGENET_STANDARD_MEAN,
+ IMAGENET_STANDARD_STD,
+ OPENAI_CLIP_MEAN,
+ OPENAI_CLIP_STD,
+)
+
+
+if is_vision_available():
+ import PIL.Image
+ import PIL.ImageOps
+
+ PILImageResampling = PIL.Image.Resampling
+
+ if is_torchvision_available():
+ from torchvision.transforms import InterpolationMode
+
+ pil_torch_interpolation_mapping = {
+ PILImageResampling.NEAREST: InterpolationMode.NEAREST_EXACT,
+ PILImageResampling.BOX: InterpolationMode.BOX,
+ PILImageResampling.BILINEAR: InterpolationMode.BILINEAR,
+ PILImageResampling.HAMMING: InterpolationMode.HAMMING,
+ PILImageResampling.BICUBIC: InterpolationMode.BICUBIC,
+ PILImageResampling.LANCZOS: InterpolationMode.LANCZOS,
+ }
+ else:
+ pil_torch_interpolation_mapping = {}
+
+
+if is_torch_available():
+ import torch
+
+
+logger = logging.get_logger(__name__)
+
+
+ImageInput = Union[
+ "PIL.Image.Image", np.ndarray, "torch.Tensor", list["PIL.Image.Image"], list[np.ndarray], list["torch.Tensor"]
+]
+
+
+class ChannelDimension(ExplicitEnum):
+ FIRST = "channels_first"
+ LAST = "channels_last"
+
+
+class AnnotationFormat(ExplicitEnum):
+ COCO_DETECTION = "coco_detection"
+ COCO_PANOPTIC = "coco_panoptic"
+
+
+class AnnotionFormat(ExplicitEnum):
+ COCO_DETECTION = AnnotationFormat.COCO_DETECTION.value
+ COCO_PANOPTIC = AnnotationFormat.COCO_PANOPTIC.value
+
+
+AnnotationType = dict[str, Union[int, str, list[dict]]]
+
+
+def is_pil_image(img):
+ return is_vision_available() and isinstance(img, PIL.Image.Image)
+
+
+class ImageType(ExplicitEnum):
+ PIL = "pillow"
+ TORCH = "torch"
+ NUMPY = "numpy"
+ TENSORFLOW = "tensorflow"
+ JAX = "jax"
+
+
+def get_image_type(image):
+ if is_pil_image(image):
+ return ImageType.PIL
+ if is_torch_tensor(image):
+ return ImageType.TORCH
+ if is_numpy_array(image):
+ return ImageType.NUMPY
+ if is_tf_tensor(image):
+ return ImageType.TENSORFLOW
+ if is_jax_tensor(image):
+ return ImageType.JAX
+ raise ValueError(f"Unrecognized image type {type(image)}")
+
+
+def is_valid_image(img):
+ return is_pil_image(img) or is_numpy_array(img) or is_torch_tensor(img) or is_tf_tensor(img) or is_jax_tensor(img)
+
+
+def is_valid_list_of_images(images: list):
+ return images and all(is_valid_image(image) for image in images)
+
+
+def concatenate_list(input_list):
+ if isinstance(input_list[0], list):
+ return [item for sublist in input_list for item in sublist]
+ elif isinstance(input_list[0], np.ndarray):
+ return np.concatenate(input_list, axis=0)
+ elif isinstance(input_list[0], torch.Tensor):
+ return torch.cat(input_list, dim=0)
+
+
+def valid_images(imgs):
+ # If we have an list of images, make sure every image is valid
+ if isinstance(imgs, (list, tuple)):
+ for img in imgs:
+ if not valid_images(img):
+ return False
+ # If not a list of tuple, we have been given a single image or batched tensor of images
+ elif not is_valid_image(imgs):
+ return False
+ return True
+
+
+def is_batched(img):
+ if isinstance(img, (list, tuple)):
+ return is_valid_image(img[0])
+ return False
+
+
+def is_scaled_image(image: np.ndarray) -> bool:
+ """
+ Checks to see whether the pixel values have already been rescaled to [0, 1].
+ """
+ if image.dtype == np.uint8:
+ return False
+
+ # It's possible the image has pixel values in [0, 255] but is of floating type
+ return np.min(image) >= 0 and np.max(image) <= 1
+
+
+def make_list_of_images(images, expected_ndims: int = 3) -> list[ImageInput]:
+ """
+ Ensure that the output is a list of images. If the input is a single image, it is converted to a list of length 1.
+ If the input is a batch of images, it is converted to a list of images.
+
+ Args:
+ images (`ImageInput`):
+ Image of images to turn into a list of images.
+ expected_ndims (`int`, *optional*, defaults to 3):
+ Expected number of dimensions for a single input image. If the input image has a different number of
+ dimensions, an error is raised.
+ """
+ if is_batched(images):
+ return images
+
+ # Either the input is a single image, in which case we create a list of length 1
+ if is_pil_image(images):
+ # PIL images are never batched
+ return [images]
+
+ if is_valid_image(images):
+ if images.ndim == expected_ndims + 1:
+ # Batch of images
+ images = list(images)
+ elif images.ndim == expected_ndims:
+ # Single image
+ images = [images]
+ else:
+ raise ValueError(
+ f"Invalid image shape. Expected either {expected_ndims + 1} or {expected_ndims} dimensions, but got"
+ f" {images.ndim} dimensions."
+ )
+ return images
+ raise ValueError(
+ "Invalid image type. Expected either PIL.Image.Image, numpy.ndarray, torch.Tensor, tf.Tensor or "
+ f"jax.ndarray, but got {type(images)}."
+ )
+
+
+def make_flat_list_of_images(
+ images: Union[list[ImageInput], ImageInput],
+ expected_ndims: int = 3,
+) -> ImageInput:
+ """
+ Ensure that the output is a flat list of images. If the input is a single image, it is converted to a list of length 1.
+ If the input is a nested list of images, it is converted to a flat list of images.
+ Args:
+ images (`Union[list[ImageInput], ImageInput]`):
+ The input image.
+ expected_ndims (`int`, *optional*, defaults to 3):
+ The expected number of dimensions for a single input image.
+ Returns:
+ list: A list of images or a 4d array of images.
+ """
+ # If the input is a nested list of images, we flatten it
+ if (
+ isinstance(images, (list, tuple))
+ and all(isinstance(images_i, (list, tuple)) for images_i in images)
+ and all(is_valid_list_of_images(images_i) or not images_i for images_i in images)
+ ):
+ return [img for img_list in images for img in img_list]
+
+ if isinstance(images, (list, tuple)) and is_valid_list_of_images(images):
+ if is_pil_image(images[0]) or images[0].ndim == expected_ndims:
+ return images
+ if images[0].ndim == expected_ndims + 1:
+ return [img for img_list in images for img in img_list]
+
+ if is_valid_image(images):
+ if is_pil_image(images) or images.ndim == expected_ndims:
+ return [images]
+ if images.ndim == expected_ndims + 1:
+ return list(images)
+
+ raise ValueError(f"Could not make a flat list of images from {images}")
+
+
+def make_nested_list_of_images(
+ images: Union[list[ImageInput], ImageInput],
+ expected_ndims: int = 3,
+) -> list[ImageInput]:
+ """
+ Ensure that the output is a nested list of images.
+ Args:
+ images (`Union[list[ImageInput], ImageInput]`):
+ The input image.
+ expected_ndims (`int`, *optional*, defaults to 3):
+ The expected number of dimensions for a single input image.
+ Returns:
+ list: A list of list of images or a list of 4d array of images.
+ """
+ # If it's a list of batches, it's already in the right format
+ if (
+ isinstance(images, (list, tuple))
+ and all(isinstance(images_i, (list, tuple)) for images_i in images)
+ and all(is_valid_list_of_images(images_i) or not images_i for images_i in images)
+ ):
+ return images
+
+ # If it's a list of images, it's a single batch, so convert it to a list of lists
+ if isinstance(images, (list, tuple)) and is_valid_list_of_images(images):
+ if is_pil_image(images[0]) or images[0].ndim == expected_ndims:
+ return [images]
+ if images[0].ndim == expected_ndims + 1:
+ return [list(image) for image in images]
+
+ # If it's a single image, convert it to a list of lists
+ if is_valid_image(images):
+ if is_pil_image(images) or images.ndim == expected_ndims:
+ return [[images]]
+ if images.ndim == expected_ndims + 1:
+ return [list(images)]
+
+ raise ValueError("Invalid input type. Must be a single image, a list of images, or a list of batches of images.")
+
+
+def to_numpy_array(img) -> np.ndarray:
+ if not is_valid_image(img):
+ raise ValueError(f"Invalid image type: {type(img)}")
+
+ if is_vision_available() and isinstance(img, PIL.Image.Image):
+ return np.array(img)
+ return to_numpy(img)
+
+
+def infer_channel_dimension_format(
+ image: np.ndarray, num_channels: Optional[Union[int, tuple[int, ...]]] = None
+) -> ChannelDimension:
+ """
+ Infers the channel dimension format of `image`.
+
+ Args:
+ image (`np.ndarray`):
+ The image to infer the channel dimension of.
+ num_channels (`int` or `tuple[int, ...]`, *optional*, defaults to `(1, 3)`):
+ The number of channels of the image.
+
+ Returns:
+ The channel dimension of the image.
+ """
+ num_channels = num_channels if num_channels is not None else (1, 3)
+ num_channels = (num_channels,) if isinstance(num_channels, int) else num_channels
+
+ if image.ndim == 3:
+ first_dim, last_dim = 0, 2
+ elif image.ndim == 4:
+ first_dim, last_dim = 1, 3
+ elif image.ndim == 5:
+ first_dim, last_dim = 2, 4
+ else:
+ raise ValueError(f"Unsupported number of image dimensions: {image.ndim}")
+
+ if image.shape[first_dim] in num_channels and image.shape[last_dim] in num_channels:
+ logger.warning(
+ f"The channel dimension is ambiguous. Got image shape {image.shape}. Assuming channels are the first dimension. Use the [input_data_format](https://huggingface.co/docs/transformers/main/internal/image_processing_utils#transformers.image_transforms.rescale.input_data_format) parameter to assign the channel dimension."
+ )
+ return ChannelDimension.FIRST
+ elif image.shape[first_dim] in num_channels:
+ return ChannelDimension.FIRST
+ elif image.shape[last_dim] in num_channels:
+ return ChannelDimension.LAST
+ raise ValueError("Unable to infer channel dimension format")
+
+
+def get_channel_dimension_axis(
+ image: np.ndarray, input_data_format: Optional[Union[ChannelDimension, str]] = None
+) -> int:
+ """
+ Returns the channel dimension axis of the image.
+
+ Args:
+ image (`np.ndarray`):
+ The image to get the channel dimension axis of.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the image. If `None`, will infer the channel dimension from the image.
+
+ Returns:
+ The channel dimension axis of the image.
+ """
+ if input_data_format is None:
+ input_data_format = infer_channel_dimension_format(image)
+ if input_data_format == ChannelDimension.FIRST:
+ return image.ndim - 3
+ elif input_data_format == ChannelDimension.LAST:
+ return image.ndim - 1
+ raise ValueError(f"Unsupported data format: {input_data_format}")
+
+
+def get_image_size(image: np.ndarray, channel_dim: Optional[ChannelDimension] = None) -> tuple[int, int]:
+ """
+ Returns the (height, width) dimensions of the image.
+
+ Args:
+ image (`np.ndarray`):
+ The image to get the dimensions of.
+ channel_dim (`ChannelDimension`, *optional*):
+ Which dimension the channel dimension is in. If `None`, will infer the channel dimension from the image.
+
+ Returns:
+ A tuple of the image's height and width.
+ """
+ if channel_dim is None:
+ channel_dim = infer_channel_dimension_format(image)
+
+ if channel_dim == ChannelDimension.FIRST:
+ return image.shape[-2], image.shape[-1]
+ elif channel_dim == ChannelDimension.LAST:
+ return image.shape[-3], image.shape[-2]
+ else:
+ raise ValueError(f"Unsupported data format: {channel_dim}")
+
+
+def get_image_size_for_max_height_width(
+ image_size: tuple[int, int],
+ max_height: int,
+ max_width: int,
+) -> tuple[int, int]:
+ """
+ Computes the output image size given the input image and the maximum allowed height and width. Keep aspect ratio.
+ Important, even if image_height < max_height and image_width < max_width, the image will be resized
+ to at least one of the edges be equal to max_height or max_width.
+
+ For example:
+ - input_size: (100, 200), max_height: 50, max_width: 50 -> output_size: (25, 50)
+ - input_size: (100, 200), max_height: 200, max_width: 500 -> output_size: (200, 400)
+
+ Args:
+ image_size (`tuple[int, int]`):
+ The image to resize.
+ max_height (`int`):
+ The maximum allowed height.
+ max_width (`int`):
+ The maximum allowed width.
+ """
+ height, width = image_size
+ height_scale = max_height / height
+ width_scale = max_width / width
+ min_scale = min(height_scale, width_scale)
+ new_height = int(height * min_scale)
+ new_width = int(width * min_scale)
+ return new_height, new_width
+
+
+def is_valid_annotation_coco_detection(annotation: dict[str, Union[list, tuple]]) -> bool:
+ if (
+ isinstance(annotation, dict)
+ and "image_id" in annotation
+ and "annotations" in annotation
+ and isinstance(annotation["annotations"], (list, tuple))
+ and (
+ # an image can have no annotations
+ len(annotation["annotations"]) == 0 or isinstance(annotation["annotations"][0], dict)
+ )
+ ):
+ return True
+ return False
+
+
+def is_valid_annotation_coco_panoptic(annotation: dict[str, Union[list, tuple]]) -> bool:
+ if (
+ isinstance(annotation, dict)
+ and "image_id" in annotation
+ and "segments_info" in annotation
+ and "file_name" in annotation
+ and isinstance(annotation["segments_info"], (list, tuple))
+ and (
+ # an image can have no segments
+ len(annotation["segments_info"]) == 0 or isinstance(annotation["segments_info"][0], dict)
+ )
+ ):
+ return True
+ return False
+
+
+def valid_coco_detection_annotations(annotations: Iterable[dict[str, Union[list, tuple]]]) -> bool:
+ return all(is_valid_annotation_coco_detection(ann) for ann in annotations)
+
+
+def valid_coco_panoptic_annotations(annotations: Iterable[dict[str, Union[list, tuple]]]) -> bool:
+ return all(is_valid_annotation_coco_panoptic(ann) for ann in annotations)
+
+
+def load_image(image: Union[str, "PIL.Image.Image"], timeout: Optional[float] = None) -> "PIL.Image.Image":
+ """
+ Loads `image` to a PIL Image.
+
+ Args:
+ image (`str` or `PIL.Image.Image`):
+ The image to convert to the PIL Image format.
+ timeout (`float`, *optional*):
+ The timeout value in seconds for the URL request.
+
+ Returns:
+ `PIL.Image.Image`: A PIL Image.
+ """
+ requires_backends(load_image, ["vision"])
+ if isinstance(image, str):
+ if image.startswith("http://") or image.startswith("https://"):
+ # We need to actually check for a real protocol, otherwise it's impossible to use a local file
+ # like http_huggingface_co.png
+ image = PIL.Image.open(BytesIO(requests.get(image, timeout=timeout).content))
+ elif os.path.isfile(image):
+ image = PIL.Image.open(image)
+ else:
+ if image.startswith("data:image/"):
+ image = image.split(",")[1]
+
+ # Try to load as base64
+ try:
+ b64 = base64.decodebytes(image.encode())
+ image = PIL.Image.open(BytesIO(b64))
+ except Exception as e:
+ raise ValueError(
+ f"Incorrect image source. Must be a valid URL starting with `http://` or `https://`, a valid path to an image file, or a base64 encoded string. Got {image}. Failed with {e}"
+ )
+ elif not isinstance(image, PIL.Image.Image):
+ raise TypeError(
+ "Incorrect format used for image. Should be an url linking to an image, a base64 string, a local path, or a PIL image."
+ )
+ image = PIL.ImageOps.exif_transpose(image)
+ image = image.convert("RGB")
+ return image
+
+
+def load_images(
+ images: Union[list, tuple, str, "PIL.Image.Image"], timeout: Optional[float] = None
+) -> Union["PIL.Image.Image", list["PIL.Image.Image"], list[list["PIL.Image.Image"]]]:
+ """Loads images, handling different levels of nesting.
+
+ Args:
+ images: A single image, a list of images, or a list of lists of images to load.
+ timeout: Timeout for loading images.
+
+ Returns:
+ A single image, a list of images, a list of lists of images.
+ """
+ if isinstance(images, (list, tuple)):
+ if len(images) and isinstance(images[0], (list, tuple)):
+ return [[load_image(image, timeout=timeout) for image in image_group] for image_group in images]
+ else:
+ return [load_image(image, timeout=timeout) for image in images]
+ else:
+ return load_image(images, timeout=timeout)
+
+
+def validate_preprocess_arguments(
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ do_pad: Optional[bool] = None,
+ pad_size: Optional[Union[dict[str, int], int]] = None,
+ do_center_crop: Optional[bool] = None,
+ crop_size: Optional[dict[str, int]] = None,
+ do_resize: Optional[bool] = None,
+ size: Optional[dict[str, int]] = None,
+ resample: Optional["PILImageResampling"] = None,
+ interpolation: Optional["InterpolationMode"] = None,
+):
+ """
+ Checks validity of typically used arguments in an `ImageProcessor` `preprocess` method.
+ Raises `ValueError` if arguments incompatibility is caught.
+ Many incompatibilities are model-specific. `do_pad` sometimes needs `size_divisor`,
+ sometimes `size_divisibility`, and sometimes `size`. New models and processors added should follow
+ existing arguments when possible.
+
+ """
+ if do_rescale and rescale_factor is None:
+ raise ValueError("`rescale_factor` must be specified if `do_rescale` is `True`.")
+
+ if do_pad and pad_size is None:
+ # Processors pad images using different args depending on the model, so the below check is pointless
+ # but we keep it for BC for now. TODO: remove in v5
+ # Usually padding can be called with:
+ # - "pad_size/size" if we're padding to specific values
+ # - "size_divisor" if we're padding to any value divisible by X
+ # - "None" if we're padding to the maximum size image in batch
+ raise ValueError(
+ "Depending on the model, `size_divisor` or `pad_size` or `size` must be specified if `do_pad` is `True`."
+ )
+
+ if do_normalize and (image_mean is None or image_std is None):
+ raise ValueError("`image_mean` and `image_std` must both be specified if `do_normalize` is `True`.")
+
+ if do_center_crop and crop_size is None:
+ raise ValueError("`crop_size` must be specified if `do_center_crop` is `True`.")
+
+ if interpolation is not None and resample is not None:
+ raise ValueError(
+ "Only one of `interpolation` and `resample` should be specified, depending on image processor type."
+ )
+
+ if do_resize and not (size is not None and (resample is not None or interpolation is not None)):
+ raise ValueError("`size` and `resample/interpolation` must be specified if `do_resize` is `True`.")
+
+
+# In the future we can add a TF implementation here when we have TF models.
+class ImageFeatureExtractionMixin:
+ """
+ Mixin that contain utilities for preparing image features.
+ """
+
+ def _ensure_format_supported(self, image):
+ if not isinstance(image, (PIL.Image.Image, np.ndarray)) and not is_torch_tensor(image):
+ raise ValueError(
+ f"Got type {type(image)} which is not supported, only `PIL.Image.Image`, `np.ndarray` and "
+ "`torch.Tensor` are."
+ )
+
+ def to_pil_image(self, image, rescale=None):
+ """
+ Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if
+ needed.
+
+ Args:
+ image (`PIL.Image.Image` or `numpy.ndarray` or `torch.Tensor`):
+ The image to convert to the PIL Image format.
+ rescale (`bool`, *optional*):
+ Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will
+ default to `True` if the image type is a floating type, `False` otherwise.
+ """
+ self._ensure_format_supported(image)
+
+ if is_torch_tensor(image):
+ image = image.numpy()
+
+ if isinstance(image, np.ndarray):
+ if rescale is None:
+ # rescale default to the array being of floating type.
+ rescale = isinstance(image.flat[0], np.floating)
+ # If the channel as been moved to first dim, we put it back at the end.
+ if image.ndim == 3 and image.shape[0] in [1, 3]:
+ image = image.transpose(1, 2, 0)
+ if rescale:
+ image = image * 255
+ image = image.astype(np.uint8)
+ return PIL.Image.fromarray(image)
+ return image
+
+ def convert_rgb(self, image):
+ """
+ Converts `PIL.Image.Image` to RGB format.
+
+ Args:
+ image (`PIL.Image.Image`):
+ The image to convert.
+ """
+ self._ensure_format_supported(image)
+ if not isinstance(image, PIL.Image.Image):
+ return image
+
+ return image.convert("RGB")
+
+ def rescale(self, image: np.ndarray, scale: Union[float, int]) -> np.ndarray:
+ """
+ Rescale a numpy image by scale amount
+ """
+ self._ensure_format_supported(image)
+ return image * scale
+
+ def to_numpy_array(self, image, rescale=None, channel_first=True):
+ """
+ Converts `image` to a numpy array. Optionally rescales it and puts the channel dimension as the first
+ dimension.
+
+ Args:
+ image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
+ The image to convert to a NumPy array.
+ rescale (`bool`, *optional*):
+ Whether or not to apply the scaling factor (to make pixel values floats between 0. and 1.). Will
+ default to `True` if the image is a PIL Image or an array/tensor of integers, `False` otherwise.
+ channel_first (`bool`, *optional*, defaults to `True`):
+ Whether or not to permute the dimensions of the image to put the channel dimension first.
+ """
+ self._ensure_format_supported(image)
+
+ if isinstance(image, PIL.Image.Image):
+ image = np.array(image)
+
+ if is_torch_tensor(image):
+ image = image.numpy()
+
+ rescale = isinstance(image.flat[0], np.integer) if rescale is None else rescale
+
+ if rescale:
+ image = self.rescale(image.astype(np.float32), 1 / 255.0)
+
+ if channel_first and image.ndim == 3:
+ image = image.transpose(2, 0, 1)
+
+ return image
+
+ def expand_dims(self, image):
+ """
+ Expands 2-dimensional `image` to 3 dimensions.
+
+ Args:
+ image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
+ The image to expand.
+ """
+ self._ensure_format_supported(image)
+
+ # Do nothing if PIL image
+ if isinstance(image, PIL.Image.Image):
+ return image
+
+ if is_torch_tensor(image):
+ image = image.unsqueeze(0)
+ else:
+ image = np.expand_dims(image, axis=0)
+ return image
+
+ def normalize(self, image, mean, std, rescale=False):
+ """
+ Normalizes `image` with `mean` and `std`. Note that this will trigger a conversion of `image` to a NumPy array
+ if it's a PIL Image.
+
+ Args:
+ image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
+ The image to normalize.
+ mean (`list[float]` or `np.ndarray` or `torch.Tensor`):
+ The mean (per channel) to use for normalization.
+ std (`list[float]` or `np.ndarray` or `torch.Tensor`):
+ The standard deviation (per channel) to use for normalization.
+ rescale (`bool`, *optional*, defaults to `False`):
+ Whether or not to rescale the image to be between 0 and 1. If a PIL image is provided, scaling will
+ happen automatically.
+ """
+ self._ensure_format_supported(image)
+
+ if isinstance(image, PIL.Image.Image):
+ image = self.to_numpy_array(image, rescale=True)
+ # If the input image is a PIL image, it automatically gets rescaled. If it's another
+ # type it may need rescaling.
+ elif rescale:
+ if isinstance(image, np.ndarray):
+ image = self.rescale(image.astype(np.float32), 1 / 255.0)
+ elif is_torch_tensor(image):
+ image = self.rescale(image.float(), 1 / 255.0)
+
+ if isinstance(image, np.ndarray):
+ if not isinstance(mean, np.ndarray):
+ mean = np.array(mean).astype(image.dtype)
+ if not isinstance(std, np.ndarray):
+ std = np.array(std).astype(image.dtype)
+ elif is_torch_tensor(image):
+ import torch
+
+ if not isinstance(mean, torch.Tensor):
+ if isinstance(mean, np.ndarray):
+ mean = torch.from_numpy(mean)
+ else:
+ mean = torch.tensor(mean)
+ if not isinstance(std, torch.Tensor):
+ if isinstance(std, np.ndarray):
+ std = torch.from_numpy(std)
+ else:
+ std = torch.tensor(std)
+
+ if image.ndim == 3 and image.shape[0] in [1, 3]:
+ return (image - mean[:, None, None]) / std[:, None, None]
+ else:
+ return (image - mean) / std
+
+ def resize(self, image, size, resample=None, default_to_square=True, max_size=None):
+ """
+ Resizes `image`. Enforces conversion of input to PIL.Image.
+
+ Args:
+ image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
+ The image to resize.
+ size (`int` or `tuple[int, int]`):
+ The size to use for resizing the image. If `size` is a sequence like (h, w), output size will be
+ matched to this.
+
+ If `size` is an int and `default_to_square` is `True`, then image will be resized to (size, size). If
+ `size` is an int and `default_to_square` is `False`, then smaller edge of the image will be matched to
+ this number. i.e, if height > width, then image will be rescaled to (size * height / width, size).
+ resample (`int`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+ The filter to user for resampling.
+ default_to_square (`bool`, *optional*, defaults to `True`):
+ How to convert `size` when it is a single int. If set to `True`, the `size` will be converted to a
+ square (`size`,`size`). If set to `False`, will replicate
+ [`torchvision.transforms.Resize`](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.Resize)
+ with support for resizing only the smallest edge and providing an optional `max_size`.
+ max_size (`int`, *optional*, defaults to `None`):
+ The maximum allowed for the longer edge of the resized image: if the longer edge of the image is
+ greater than `max_size` after being resized according to `size`, then the image is resized again so
+ that the longer edge is equal to `max_size`. As a result, `size` might be overruled, i.e the smaller
+ edge may be shorter than `size`. Only used if `default_to_square` is `False`.
+
+ Returns:
+ image: A resized `PIL.Image.Image`.
+ """
+ resample = resample if resample is not None else PILImageResampling.BILINEAR
+
+ self._ensure_format_supported(image)
+
+ if not isinstance(image, PIL.Image.Image):
+ image = self.to_pil_image(image)
+
+ if isinstance(size, list):
+ size = tuple(size)
+
+ if isinstance(size, int) or len(size) == 1:
+ if default_to_square:
+ size = (size, size) if isinstance(size, int) else (size[0], size[0])
+ else:
+ width, height = image.size
+ # specified size only for the smallest edge
+ short, long = (width, height) if width <= height else (height, width)
+ requested_new_short = size if isinstance(size, int) else size[0]
+
+ if short == requested_new_short:
+ return image
+
+ new_short, new_long = requested_new_short, int(requested_new_short * long / short)
+
+ if max_size is not None:
+ if max_size <= requested_new_short:
+ raise ValueError(
+ f"max_size = {max_size} must be strictly greater than the requested "
+ f"size for the smaller edge size = {size}"
+ )
+ if new_long > max_size:
+ new_short, new_long = int(max_size * new_short / new_long), max_size
+
+ size = (new_short, new_long) if width <= height else (new_long, new_short)
+
+ return image.resize(size, resample=resample)
+
+ def center_crop(self, image, size):
+ """
+ Crops `image` to the given size using a center crop. Note that if the image is too small to be cropped to the
+ size given, it will be padded (so the returned result has the size asked).
+
+ Args:
+ image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor` of shape (n_channels, height, width) or (height, width, n_channels)):
+ The image to resize.
+ size (`int` or `tuple[int, int]`):
+ The size to which crop the image.
+
+ Returns:
+ new_image: A center cropped `PIL.Image.Image` or `np.ndarray` or `torch.Tensor` of shape: (n_channels,
+ height, width).
+ """
+ self._ensure_format_supported(image)
+
+ if not isinstance(size, tuple):
+ size = (size, size)
+
+ # PIL Image.size is (width, height) but NumPy array and torch Tensors have (height, width)
+ if is_torch_tensor(image) or isinstance(image, np.ndarray):
+ if image.ndim == 2:
+ image = self.expand_dims(image)
+ image_shape = image.shape[1:] if image.shape[0] in [1, 3] else image.shape[:2]
+ else:
+ image_shape = (image.size[1], image.size[0])
+
+ top = (image_shape[0] - size[0]) // 2
+ bottom = top + size[0] # In case size is odd, (image_shape[0] + size[0]) // 2 won't give the proper result.
+ left = (image_shape[1] - size[1]) // 2
+ right = left + size[1] # In case size is odd, (image_shape[1] + size[1]) // 2 won't give the proper result.
+
+ # For PIL Images we have a method to crop directly.
+ if isinstance(image, PIL.Image.Image):
+ return image.crop((left, top, right, bottom))
+
+ # Check if image is in (n_channels, height, width) or (height, width, n_channels) format
+ channel_first = image.shape[0] in [1, 3]
+
+ # Transpose (height, width, n_channels) format images
+ if not channel_first:
+ if isinstance(image, np.ndarray):
+ image = image.transpose(2, 0, 1)
+ if is_torch_tensor(image):
+ image = image.permute(2, 0, 1)
+
+ # Check if cropped area is within image boundaries
+ if top >= 0 and bottom <= image_shape[0] and left >= 0 and right <= image_shape[1]:
+ return image[..., top:bottom, left:right]
+
+ # Otherwise, we may need to pad if the image is too small. Oh joy...
+ new_shape = image.shape[:-2] + (max(size[0], image_shape[0]), max(size[1], image_shape[1]))
+ if isinstance(image, np.ndarray):
+ new_image = np.zeros_like(image, shape=new_shape)
+ elif is_torch_tensor(image):
+ new_image = image.new_zeros(new_shape)
+
+ top_pad = (new_shape[-2] - image_shape[0]) // 2
+ bottom_pad = top_pad + image_shape[0]
+ left_pad = (new_shape[-1] - image_shape[1]) // 2
+ right_pad = left_pad + image_shape[1]
+ new_image[..., top_pad:bottom_pad, left_pad:right_pad] = image
+
+ top += top_pad
+ bottom += top_pad
+ left += left_pad
+ right += left_pad
+
+ new_image = new_image[
+ ..., max(0, top) : min(new_image.shape[-2], bottom), max(0, left) : min(new_image.shape[-1], right)
+ ]
+
+ return new_image
+
+ def flip_channel_order(self, image):
+ """
+ Flips the channel order of `image` from RGB to BGR, or vice versa. Note that this will trigger a conversion of
+ `image` to a NumPy array if it's a PIL Image.
+
+ Args:
+ image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
+ The image whose color channels to flip. If `np.ndarray` or `torch.Tensor`, the channel dimension should
+ be first.
+ """
+ self._ensure_format_supported(image)
+
+ if isinstance(image, PIL.Image.Image):
+ image = self.to_numpy_array(image)
+
+ return image[::-1, :, :]
+
+ def rotate(self, image, angle, resample=None, expand=0, center=None, translate=None, fillcolor=None):
+ """
+ Returns a rotated copy of `image`. This method returns a copy of `image`, rotated the given number of degrees
+ counter clockwise around its centre.
+
+ Args:
+ image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
+ The image to rotate. If `np.ndarray` or `torch.Tensor`, will be converted to `PIL.Image.Image` before
+ rotating.
+
+ Returns:
+ image: A rotated `PIL.Image.Image`.
+ """
+ resample = resample if resample is not None else PIL.Image.NEAREST
+
+ self._ensure_format_supported(image)
+
+ if not isinstance(image, PIL.Image.Image):
+ image = self.to_pil_image(image)
+
+ return image.rotate(
+ angle, resample=resample, expand=expand, center=center, translate=translate, fillcolor=fillcolor
+ )
+
+
+def validate_annotations(
+ annotation_format: AnnotationFormat,
+ supported_annotation_formats: tuple[AnnotationFormat, ...],
+ annotations: list[dict],
+) -> None:
+ if annotation_format not in supported_annotation_formats:
+ raise ValueError(f"Unsupported annotation format: {format} must be one of {supported_annotation_formats}")
+
+ if annotation_format is AnnotationFormat.COCO_DETECTION:
+ if not valid_coco_detection_annotations(annotations):
+ raise ValueError(
+ "Invalid COCO detection annotations. Annotations must a dict (single image) or list of dicts "
+ "(batch of images) with the following keys: `image_id` and `annotations`, with the latter "
+ "being a list of annotations in the COCO format."
+ )
+
+ if annotation_format is AnnotationFormat.COCO_PANOPTIC:
+ if not valid_coco_panoptic_annotations(annotations):
+ raise ValueError(
+ "Invalid COCO panoptic annotations. Annotations must a dict (single image) or list of dicts "
+ "(batch of images) with the following keys: `image_id`, `file_name` and `segments_info`, with "
+ "the latter being a list of annotations in the COCO format."
+ )
+
+
+def validate_kwargs(valid_processor_keys: list[str], captured_kwargs: list[str]):
+ unused_keys = set(captured_kwargs).difference(set(valid_processor_keys))
+ if unused_keys:
+ unused_key_str = ", ".join(unused_keys)
+ # TODO raise a warning here instead of simply logging?
+ logger.warning(f"Unused or unrecognized kwargs: {unused_key_str}.")
+
+
+@dataclass(frozen=True)
+class SizeDict:
+ """
+ Hashable dictionary to store image size information.
+ """
+
+ height: Optional[int] = None
+ width: Optional[int] = None
+ longest_edge: Optional[int] = None
+ shortest_edge: Optional[int] = None
+ max_height: Optional[int] = None
+ max_width: Optional[int] = None
+
+ def __getitem__(self, key):
+ if hasattr(self, key):
+ return getattr(self, key)
+ raise KeyError(f"Key {key} not found in SizeDict.")
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/keras_callbacks.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/keras_callbacks.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab7fc4615b473d59c903260a8c1ec80b24f4af7b
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/keras_callbacks.py
@@ -0,0 +1,413 @@
+import logging
+import os
+from pathlib import Path
+from time import sleep
+from typing import Callable, Optional, Union
+
+import numpy as np
+import tensorflow as tf
+from huggingface_hub import Repository, create_repo
+from packaging.version import parse
+
+from . import IntervalStrategy, PreTrainedTokenizerBase
+from .modelcard import TrainingSummary
+from .modeling_tf_utils import keras
+
+
+logger = logging.getLogger(__name__)
+
+
+class KerasMetricCallback(keras.callbacks.Callback):
+ """
+ Callback to compute metrics at the end of every epoch. Unlike normal Keras metrics, these do not need to be
+ compilable by TF. It is particularly useful for common NLP metrics like BLEU and ROUGE that require string
+ operations or generation loops that cannot be compiled. Predictions (or generations) will be computed on the
+ `eval_dataset` before being passed to the `metric_fn` in `np.ndarray` format. The `metric_fn` should compute
+ metrics and return a dict mapping metric names to metric values.
+
+ We provide an example of a suitable metric_fn that computes ROUGE scores for a summarization model below. Note that
+ this example skips some post-processing for readability and simplicity, and should probably not be used as-is!
+
+ ```py
+ from datasets import load_metric
+
+ rouge_metric = load_metric("rouge")
+
+
+ def rouge_fn(predictions, labels):
+ decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True)
+ decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
+ result = rouge_metric.compute(predictions=decoded_predictions, references=decoded_labels)
+ return {key: value.mid.fmeasure * 100 for key, value in result.items()}
+ ```
+
+ The above function will return a dict containing values which will be logged like any other Keras metric:
+
+ ```
+ {'rouge1': 37.4199, 'rouge2': 13.9768, 'rougeL': 34.361, 'rougeLsum': 35.0781
+ ```
+
+ Args:
+ metric_fn (`Callable`):
+ Metric function provided by the user. It will be called with two arguments - `predictions` and `labels`.
+ These contain the model's outputs and matching labels from the dataset. It should return a dict mapping
+ metric names to numerical values.
+ eval_dataset (`tf.data.Dataset` or `dict` or `tuple` or `np.ndarray` or `tf.Tensor`):
+ Validation data to be used to generate predictions for the `metric_fn`.
+ output_cols (`list[str], *optional*):
+ A list of columns to be retained from the model output as the predictions. Defaults to all.
+ label_cols ('`list[str]`, *optional*'):
+ A list of columns to be retained from the input dataset as the labels. Will be autodetected if this is not
+ supplied.
+ batch_size (`int`, *optional*):
+ Batch size. Only used when the data is not a pre-batched `tf.data.Dataset`.
+ predict_with_generate (`bool`, *optional*, defaults to `False`):
+ Whether we should use `model.generate()` to get outputs for the model.
+ use_xla_generation (`bool`, *optional*, defaults to `False`):
+ If we're generating, whether to compile model generation with XLA. This can massively increase the speed of
+ generation (up to 100X speedup) but will require a new XLA compilation for each input shape. When using XLA
+ generation, it's a good idea to pad your inputs to the same size, or to use the `pad_to_multiple_of`
+ argument in your `tokenizer` or `DataCollator`, which will reduce the number of unique input shapes and
+ save a lot of compilation time. This option has no effect is `predict_with_generate` is `False`.
+ generate_kwargs (`dict`, *optional*):
+ Keyword arguments to pass to `model.generate()` when generating. Has no effect if `predict_with_generate`
+ is `False`.
+
+ """
+
+ def __init__(
+ self,
+ metric_fn: Callable,
+ eval_dataset: Union[tf.data.Dataset, np.ndarray, tf.Tensor, tuple, dict],
+ output_cols: Optional[list[str]] = None,
+ label_cols: Optional[list[str]] = None,
+ batch_size: Optional[int] = None,
+ predict_with_generate: bool = False,
+ use_xla_generation: bool = False,
+ generate_kwargs: Optional[dict] = None,
+ ):
+ super().__init__()
+ self.metric_fn = metric_fn
+ self.batch_size = batch_size
+ if not isinstance(eval_dataset, tf.data.Dataset):
+ if batch_size is None:
+ raise ValueError(
+ "When passing data to KerasMetricCallback that is not a pre-batched tf.data.Dataset "
+ "the batch_size argument must be set."
+ )
+ # Wrap a tf.data.Dataset around it
+ eval_dataset = tf.data.Dataset.from_tensor_slices(eval_dataset).batch(batch_size, drop_remainder=False)
+ self.eval_dataset = eval_dataset
+ self.predict_with_generate = predict_with_generate
+ self.output_cols = output_cols
+
+ # This next block attempts to parse out which elements of the dataset should be appended to the labels list
+ # that is passed to the metric_fn
+ if isinstance(eval_dataset.element_spec, tuple) and len(eval_dataset.element_spec) == 2:
+ input_spec, label_spec = eval_dataset.element_spec
+ else:
+ input_spec = eval_dataset.element_spec
+ label_spec = None
+ if label_cols is not None:
+ for label in label_cols:
+ if label not in input_spec:
+ raise ValueError(f"Label {label} is in label_cols but could not be found in the dataset inputs!")
+ self.label_cols = label_cols
+ self.use_keras_label = False
+ elif label_spec is not None:
+ # If the dataset inputs are split into a 2-tuple of inputs and labels,
+ # assume the second element is the labels
+ self.label_cols = None
+ self.use_keras_label = True
+ elif "labels" in input_spec:
+ self.label_cols = ["labels"]
+ self.use_keras_label = False
+ logging.warning("No label_cols specified for KerasMetricCallback, assuming you want the 'labels' key.")
+ elif "start_positions" in input_spec and "end_positions" in input_spec:
+ self.label_cols = ["start_positions", "end_positions"]
+ self.use_keras_label = False
+ logging.warning(
+ "No label_cols specified for KerasMetricCallback, assuming you want the "
+ "start_positions and end_positions keys."
+ )
+ else:
+ raise ValueError("Could not autodetect label_cols for KerasMetricCallback, please specify them!")
+ if parse(tf.__version__) < parse("2.7"):
+ logging.warning("TF versions less than 2.7 may encounter issues with KerasMetricCallback!")
+
+ self.use_xla_generation = use_xla_generation
+ self.generate_kwargs = {} if generate_kwargs is None else generate_kwargs
+
+ self.generation_function = None
+
+ @staticmethod
+ def _concatenate_batches(batches, padding_index=-100):
+ # If all batches are unidimensional or same length, do a simple concatenation
+ if batches[0].ndim == 1 or all(batch.shape[1] == batches[0].shape[1] for batch in batches):
+ return np.concatenate(batches, axis=0)
+
+ # Welp, they're not the same length. Let's do some padding
+ max_len = max([batch.shape[1] for batch in batches])
+ num_samples = sum([batch.shape[0] for batch in batches])
+ output = np.full_like(
+ batches[0], fill_value=padding_index, shape=[num_samples, max_len] + list(batches[0].shape[2:])
+ )
+ # i keeps track of which part of the concatenated array we're writing the next batch to
+ i = 0
+ for batch in batches:
+ output[i : i + len(batch), : batch.shape[1]] = batch
+ i += len(batch)
+ return output
+
+ def _postprocess_predictions_or_labels(self, inputs):
+ if isinstance(inputs[0], dict):
+ outputs = {}
+ for key in inputs[0]:
+ outputs[key] = self._concatenate_batches([batch[key] for batch in inputs])
+ # If it's a dict with only one key, just return the array
+ if len(outputs) == 1:
+ outputs = list(outputs.values())[0]
+ elif isinstance(inputs[0], (tuple, list)):
+ outputs = []
+ for input_list in zip(*inputs):
+ outputs.append(self._concatenate_batches(input_list))
+ if len(outputs) == 1:
+ outputs = outputs[0] # If it's a list with only one element, just return the array
+ elif isinstance(inputs[0], np.ndarray):
+ outputs = self._concatenate_batches(inputs)
+ elif isinstance(inputs[0], tf.Tensor):
+ outputs = self._concatenate_batches([tensor.numpy() for tensor in inputs])
+ else:
+ raise TypeError(f"Couldn't handle batch of type {type(inputs[0])}!")
+ return outputs
+
+ def on_epoch_end(self, epoch, logs=None):
+ if hasattr(self.model, "config"):
+ ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
+ else:
+ ignore_keys = []
+
+ main_input_name = None
+ if self.predict_with_generate:
+ # This dense conditional recognizes the case where we have an encoder-decoder model, but
+ # avoids getting tangled up when we just have a model with a layer called 'encoder'
+ if hasattr(self.model, "encoder") and hasattr(self.model.encoder, "main_input_name"):
+ main_input_name = self.model.encoder.main_input_name
+ else:
+ main_input_name = getattr(self.model, "main_input_name", "input_ids")
+
+ if self.use_xla_generation and self.generation_function is None:
+
+ def generation_function(inputs, attention_mask):
+ return self.model.generate(inputs, attention_mask=attention_mask, **self.generate_kwargs)
+
+ self.generation_function = tf.function(generation_function, jit_compile=True)
+
+ prediction_list = []
+ label_list = []
+
+ # The whole predict/generate loop is handled inside this method
+ for batch in self.eval_dataset:
+ if isinstance(batch, tuple):
+ batch, labels = batch
+ else:
+ labels = None
+ if self.predict_with_generate:
+ if isinstance(batch, dict):
+ generation_inputs = batch[main_input_name]
+ attention_mask = batch.get("attention_mask", None)
+ else:
+ generation_inputs = batch
+ attention_mask = None
+ if self.use_xla_generation:
+ predictions = self.generation_function(generation_inputs, attention_mask=attention_mask)
+ else:
+ predictions = self.model.generate(
+ generation_inputs, attention_mask=attention_mask, **self.generate_kwargs
+ )
+ else:
+ predictions = self.model.predict_on_batch(batch)
+ if isinstance(predictions, dict):
+ # This converts any dict-subclass to a regular dict
+ # Keras REALLY doesn't like it when we pass around a BatchEncoding or other derived class
+ predictions = dict(predictions)
+ if self.output_cols is not None:
+ predictions = {key: predictions[key] for key in self.output_cols}
+ else:
+ predictions = {
+ key: val for key, val in predictions.items() if key not in ignore_keys + ["loss"]
+ }
+ prediction_list.append(predictions)
+ if not self.use_keras_label:
+ labels = {key: batch[key].numpy() for key in self.label_cols}
+ elif isinstance(labels, dict):
+ labels = {key: array.numpy() for key, array in labels.items()}
+ elif isinstance(labels, (list, tuple)):
+ labels = [array.numpy() for array in labels]
+ elif isinstance(labels, tf.Tensor):
+ labels = labels.numpy()
+ else:
+ raise TypeError(f"Confused by labels of type {type(labels)}")
+ label_list.append(labels)
+
+ all_preds = self._postprocess_predictions_or_labels(prediction_list)
+ all_labels = self._postprocess_predictions_or_labels(label_list)
+
+ metric_output = self.metric_fn((all_preds, all_labels))
+ if not isinstance(metric_output, dict):
+ raise TypeError(
+ f"metric_fn should return a dict mapping metric names to values but instead returned {metric_output}"
+ )
+ # This is the critical bit - Keras passes a dict containing the loss and standard metric values for this epoch
+ # in the logs argument. Ordinarily, this is so the callback can read them, but in this case we write a bunch of
+ # new keys in there, which will then get read by the History callback and treated like any other metric value.
+ # I promise that I have it in writing from Chollet that this is okay.
+ logs.update(metric_output)
+
+
+class PushToHubCallback(keras.callbacks.Callback):
+ """
+ Callback that will save and push the model to the Hub regularly. By default, it pushes once per epoch, but this can
+ be changed with the `save_strategy` argument. Pushed models can be accessed like any other model on the hub, such
+ as with the `from_pretrained` method.
+
+ ```py
+ from transformers.keras_callbacks import PushToHubCallback
+
+ push_to_hub_callback = PushToHubCallback(
+ output_dir="./model_save",
+ tokenizer=tokenizer,
+ hub_model_id="gpt5-7xlarge",
+ )
+
+ model.fit(train_dataset, callbacks=[push_to_hub_callback])
+ ```
+
+ Args:
+ output_dir (`str`):
+ The output directory where the model predictions and checkpoints will be written and synced with the
+ repository on the Hub.
+ save_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"epoch"`):
+ The checkpoint save strategy to adopt during training. Possible values are:
+
+ - `"no"`: Save is done at the end of training.
+ - `"epoch"`: Save is done at the end of each epoch.
+ - `"steps"`: Save is done every `save_steps`
+ save_steps (`int`, *optional*):
+ The number of steps between saves when using the "steps" `save_strategy`.
+ tokenizer (`PreTrainedTokenizerBase`, *optional*):
+ The tokenizer used by the model. If supplied, will be uploaded to the repo alongside the weights.
+ hub_model_id (`str`, *optional*):
+ The name of the repository to keep in sync with the local `output_dir`. It can be a simple model ID in
+ which case the model will be pushed in your namespace. Otherwise it should be the whole repository name,
+ for instance `"user_name/model"`, which allows you to push to an organization you are a member of with
+ `"organization_name/model"`.
+
+ Will default to the name of `output_dir`.
+ hub_token (`str`, *optional*):
+ The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with
+ `hf auth login`.
+ checkpoint (`bool`, *optional*, defaults to `False`):
+ Whether to save full training checkpoints (including epoch and optimizer state) to allow training to be
+ resumed. Only usable when `save_strategy` is `"epoch"`.
+ """
+
+ def __init__(
+ self,
+ output_dir: Union[str, Path],
+ save_strategy: Union[str, IntervalStrategy] = "epoch",
+ save_steps: Optional[int] = None,
+ tokenizer: Optional[PreTrainedTokenizerBase] = None,
+ hub_model_id: Optional[str] = None,
+ hub_token: Optional[str] = None,
+ checkpoint: bool = False,
+ **model_card_args,
+ ):
+ super().__init__()
+ if checkpoint and save_strategy != "epoch":
+ raise ValueError("Cannot save checkpoints when save_strategy is not 'epoch'!")
+ if isinstance(save_strategy, str):
+ save_strategy = IntervalStrategy(save_strategy.lower())
+ self.save_strategy = save_strategy
+ if self.save_strategy == IntervalStrategy.STEPS and (not isinstance(save_steps, int) or save_steps <= 0):
+ raise ValueError("Please supply a positive integer argument for save_steps when save_strategy == 'steps'!")
+ self.save_steps = save_steps
+ output_dir = Path(output_dir)
+
+ # Create repo and retrieve repo_id
+ if hub_model_id is None:
+ hub_model_id = output_dir.absolute().name
+ self.hub_model_id = create_repo(repo_id=hub_model_id, exist_ok=True, token=hub_token).repo_id
+
+ self.output_dir = output_dir
+ self.repo = Repository(str(self.output_dir), clone_from=self.hub_model_id, token=hub_token)
+
+ self.tokenizer = tokenizer
+ self.last_job = None
+ self.checkpoint = checkpoint
+ self.training_history = None
+ self.model_card_args = model_card_args
+
+ def on_train_begin(self, logs=None):
+ # Although we can access model.history, we have no guarantees that the History callback will fire before this
+ # one, so we keep track of it here too
+ self.training_history = []
+
+ def on_train_batch_end(self, batch, logs=None):
+ if self.save_strategy == IntervalStrategy.STEPS and (batch + 1) % self.save_steps == 0:
+ if self.last_job is not None and not self.last_job.is_done:
+ return # The last upload is still running, don't start another
+ self.model.save_pretrained(self.output_dir)
+ if self.tokenizer is not None:
+ self.tokenizer.save_pretrained(self.output_dir)
+ _, self.last_job = self.repo.push_to_hub(
+ commit_message=f"Training in progress steps {batch}", blocking=False
+ )
+
+ def on_epoch_end(self, epoch, logs=None):
+ logs = logs.copy() # Don't accidentally write things that Keras will read later
+ if "epoch" not in logs:
+ logs["epoch"] = epoch
+ self.training_history.append(logs)
+ if self.save_strategy == IntervalStrategy.EPOCH:
+ if self.last_job is not None and not self.last_job.is_done:
+ return # The last upload is still running, don't start another
+ self.model.save_pretrained(self.output_dir)
+ if self.tokenizer is not None:
+ self.tokenizer.save_pretrained(self.output_dir)
+ if self.checkpoint:
+ checkpoint_dir = os.path.join(self.output_dir, "checkpoint")
+ self.model._save_checkpoint(checkpoint_dir, epoch)
+ train_summary = TrainingSummary.from_keras(
+ model=self.model,
+ model_name=self.hub_model_id,
+ keras_history=self.training_history,
+ **self.model_card_args,
+ )
+ model_card = train_summary.to_model_card()
+ with (self.output_dir / "README.md").open("w") as f:
+ f.write(model_card)
+ _, self.last_job = self.repo.push_to_hub(
+ commit_message=f"Training in progress epoch {epoch}", blocking=False
+ )
+
+ def on_train_end(self, logs=None):
+ # Makes sure the latest version of the model is uploaded
+ if self.last_job is not None and not self.last_job.is_done:
+ logging.info("Pushing the last epoch to the Hub, this may take a while...")
+ while not self.last_job.is_done:
+ sleep(1)
+ else:
+ self.model.save_pretrained(self.output_dir)
+ if self.tokenizer is not None:
+ self.tokenizer.save_pretrained(self.output_dir)
+ train_summary = TrainingSummary.from_keras(
+ model=self.model,
+ model_name=self.hub_model_id,
+ keras_history=self.training_history,
+ **self.model_card_args,
+ )
+ model_card = train_summary.to_model_card()
+ with (self.output_dir / "README.md").open("w") as f:
+ f.write(model_card)
+ self.repo.push_to_hub(commit_message="End of training", blocking=True)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/model_debugging_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/model_debugging_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c7b47c04fd508319e9f13511e5621260108dc2b
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/model_debugging_utils.py
@@ -0,0 +1,456 @@
+# Copyright 2025 The HuggingFace Inc. team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+import json
+import os
+import re
+from contextlib import contextmanager, redirect_stdout
+from io import StringIO
+from typing import Optional
+
+from .utils import logging
+from .utils.import_utils import is_torch_available, requires
+
+
+if is_torch_available():
+ import torch
+ from safetensors.torch import save_file
+
+ _torch_distributed_available = False
+ # Note to code inspectors: this toolbox is intended for people who add models to `transformers`.
+ if torch.distributed.is_available():
+ import torch.distributed.tensor
+
+ _torch_distributed_available = True
+else:
+ _torch_distributed_available = False
+
+
+logger = logging.get_logger(__name__)
+
+
+def _is_rank_zero():
+ """Return True if rank=0 or we aren't running distributed."""
+ if not (_torch_distributed_available and torch.distributed.is_initialized()):
+ return True
+ return torch.distributed.get_rank() == 0
+
+
+MEMORY_ADDRESS_REGEX = re.compile(r"object at 0x[0-9A-Fa-f]+")
+
+
+def _sanitize_repr_for_diff(x_str: str) -> str:
+ """
+ Replace memory addresses in an object's repr with a stable placeholder
+ so that beautiful JSON diffs won't be ruined by ephemeral addresses.
+ """
+ return MEMORY_ADDRESS_REGEX.sub("object at 0xXXXXXXXX", x_str)
+
+
+def _dtensor_repr(x):
+ """Return a stable string representation for a DTensor-like object."""
+ if _is_rank_zero():
+ return f"DTensor (rank0) -> {repr(x._local_tensor)}"
+ return "DTensor(non-rank0)"
+
+
+def _serialize_tensor_like_io(
+ value, debug_path: Optional[str] = None, use_repr: bool = True, path_to_value: Optional[str] = None
+):
+ """
+ Converts Tensors and DTensors to a JSON-serializable dictionary representation.
+
+ Args:
+ value: Any Python object, often including torch Tensors, lists, dicts, etc.
+ debug_path (`str`, *optional*, defaults to `None`): Directory to dump debug JSON and SafeTensors files.
+ use_repr (bool, *optional*, defaults to `True`): Whether to save a `repr()`-ized version of the tensor as the
+ `value` property in the asscoiated FULL_TENSORS.json file, or to store the full tensors in separate
+ SafeTensors file and store the relative path to that file in the `value` property in the dictionary.
+ path_to_value (`str`, *optional*, defaults to `None`): The file name for the SafeTensors file holding the full
+ tensor value if `use_repr=False`.
+
+ Returns:
+ A nested Python structure (list, dict, or sanitized string) that is safe to json.dump.
+ """
+ torch.set_printoptions(sci_mode=True)
+
+ if use_repr:
+ value_out = _repr_to_list(value)
+ elif path_to_value:
+ if not path_to_value.endswith(".safetensors"):
+ path_to_value += ".safetensors"
+
+ filepath = os.path.join(debug_path, path_to_value) if debug_path else path_to_value
+ save_file({"data": value.contiguous().detach().cpu()}, filepath)
+ value_out = f"./{path_to_value}"
+ else:
+ raise ValueError(f"{use_repr=} and {path_to_value=} cannot both be falsy.")
+
+ out = {
+ "shape": repr(value.shape),
+ "dtype": repr(value.dtype),
+ "value": value_out,
+ }
+ if value.dtype in {torch.float16, torch.float32, torch.bfloat16}:
+ out.update(
+ {
+ "mean": _sanitize_repr_for_diff(repr(value.mean())),
+ "std": _sanitize_repr_for_diff(repr(value.std())),
+ "min": _sanitize_repr_for_diff(repr(value.min())),
+ "max": _sanitize_repr_for_diff(repr(value.max())),
+ }
+ )
+ return out
+
+
+def _serialize_io(value, debug_path: Optional[str] = None, use_repr: bool = True, path_to_value: Optional[str] = None):
+ """
+ Recursively build a JSON-serializable Python structure from `value`.
+ Tensors and DTensors become either sanitized repr strings, or are saved to disk as SafeTensors files and their
+ relative paths are recorded in the returned Python structure.
+ Lists/tuples/dicts are recursed into.
+ All memory addresses are replaced with a stable placeholder.
+
+ Args:
+ value: Any Python object, often including torch Tensors, lists, dicts, etc.
+ debug_path (`str`, *optional*, defaults to `None`): Directory to dump debug JSON and SafeTensors files.
+ use_repr (bool, *optional*, defaults to `True`): Whether to save a `repr()`-ized version of the tensors as the
+ `value` property in the asscoiated FULL_TENSORS.json file, or to store full tensors in separate SafeTensors
+ files and store the relative path to that file in the `value` property.
+ path_to_value (`str`, *optional*, defaults to `None`): The file name for the SafeTensors file holding the full
+ tensor value if `use_repr=False`.
+
+ Returns:
+ A nested Python structure (list, dict, or sanitized string) that is safe to json.dump.
+ """
+ if isinstance(value, (list, tuple)):
+ return [
+ _serialize_io(v, debug_path=debug_path, use_repr=use_repr, path_to_value=f"{path_to_value}_{i}")
+ for i, v in enumerate(value)
+ ]
+
+ if isinstance(value, dict):
+ return {
+ k: _serialize_io(v, debug_path=debug_path, use_repr=use_repr, path_to_value=f"{path_to_value}_{k}")
+ for k, v in value.items()
+ }
+
+ if hasattr(value, "_local_tensor"):
+ return _serialize_tensor_like_io(
+ value._local_tensor, debug_path=debug_path, use_repr=use_repr, path_to_value=path_to_value
+ )
+
+ if isinstance(value, torch.Tensor):
+ return _serialize_tensor_like_io(value, debug_path=debug_path, use_repr=use_repr, path_to_value=path_to_value)
+
+ return _sanitize_repr_for_diff(repr(value))
+
+
+def _repr_to_list(value: torch.Tensor):
+ """
+ Converts a tensor into a sanitized multi-line string representation.
+
+ Args:
+ value (`torch.Tensor`): The tensor to represent.
+
+ Returns:
+ `list[str]`: List of string lines representing the tensor.
+ """
+ torch.set_printoptions(sci_mode=True, linewidth=120)
+ with StringIO() as buf, redirect_stdout(buf):
+ print(value) # to redirected stdout to avoid line splits
+ raw = buf.getvalue()
+ return _sanitize_repr_for_diff(raw).splitlines()
+
+
+def prune_outputs_if_children(node):
+ # if there are children, remove this node's "outputs"
+ # so we only see outputs at the leaf level
+ if node.get("children"):
+ node.pop("outputs", None)
+ for child in node["children"]:
+ prune_outputs_if_children(child)
+
+
+LAYER_SUFFIX_RE = re.compile(r"(.*)\.(\d+)$") # should be generic enough, ends with a number
+
+
+def is_layer_block(node):
+ """
+ Checks whether a node represents a layer block with submodules.
+
+ Args:
+ node (`dict`): A node from the call tree.
+
+ Returns:
+ `bool`: Whether the node is a layer block.
+ """
+ match = LAYER_SUFFIX_RE.match(node.get("module_path", ""))
+ if not match or not node.get("children"):
+ return False
+ number = match.group(2)
+ return any(f".{number}." in child.get("module_path", "") for child in node["children"])
+
+
+def prune_intermediate_layers(node):
+ """
+ Recursively removes intermediate layers from the tree to improve readability.
+ Keeps at least the first and last layers if many consecutive layers are present.
+
+ Args:
+ node (`dict`): The root or subnode to prune recursively.
+ """
+ if not node.get("children"):
+ return
+ layer_blocks = [(i, child) for i, child in enumerate(node["children"]) if is_layer_block(child)]
+
+ if len(layer_blocks) > 2:
+ to_remove = [i for i, _ in layer_blocks[1:-1]]
+ node["children"] = [child for i, child in enumerate(node["children"]) if i not in to_remove]
+
+ for child in node["children"]:
+ prune_intermediate_layers(child)
+
+
+def log_model_debug_trace(debug_path: Optional[str], model):
+ if debug_path:
+ try:
+ os.makedirs(debug_path, exist_ok=True)
+ base = os.path.join(debug_path, model._debugger_module_dump_name + "_debug_tree")
+ except Exception as e:
+ raise ValueError(f"Unexpected or existing debug_path={debug_path}.") from e
+ else:
+ base = model._debugger_module_dump_name + "_debug_tree"
+
+ logger.info(f"Writing model trace at {base}.json")
+ full_path = base + "_FULL_TENSORS.json"
+ summary_path = base + "_SUMMARY.json"
+
+ prune_outputs_if_children(model._call_tree)
+
+ with open(full_path, "w") as f:
+ json.dump(model._call_tree, f, indent=2)
+
+ # summary-only version for readability - traversing the tree again #TODO optimize?
+ def strip_values(node):
+ def clean(val):
+ if isinstance(val, dict):
+ val.pop("value", None)
+ for v in val.values():
+ clean(v)
+ elif isinstance(val, list):
+ for item in val:
+ clean(item)
+
+ clean(node.get("inputs", {}))
+ clean(node.get("outputs", {}))
+
+ for child in node.get("children", []):
+ strip_values(child)
+
+ tree_copy = json.loads(json.dumps(model._call_tree)) # deep copy
+ strip_values(tree_copy)
+
+ with open(summary_path, "w") as f:
+ json.dump(tree_copy, f, indent=2)
+
+
+def _attach_debugger_logic(
+ model,
+ debug_path: str = ".",
+ do_prune_layers: bool = True,
+ use_repr: bool = True,
+):
+ """
+ Attaches a debugging wrapper to every module in the model.
+
+ This records structured inputs and outputs during the forward pass into a call tree.
+
+ Args:
+ model (`PreTrainedModel`, `nn.Module`): Model to wrap.
+ debug_path (`str`): Optional directory to dump debug JSON files.
+ do_prune_layers (`bool`, *optional*, defaults to `True`): Whether to prune intermediate layers.
+ use_repr (bool, *optional*, defaults to `True`): Whether to save a `repr()`-ized version of the tensors as the
+ `value` property in the associated FULL_TENSORS.json file, or to store full tensors in separate SafeTensors
+ files and store the relative path to that file in the `value` property.
+ """
+ class_name = model.__class__.__name__
+
+ # Prepare data structures on the model object
+ model._call_tree = {"module_path": class_name, "inputs": None, "outputs": None, "children": []}
+ model._debugger_model_call_stack = []
+ model._debugger_module_dump_name = class_name # used for final JSON filename
+
+ if debug_path:
+ try:
+ os.makedirs(debug_path, exist_ok=True)
+ except Exception as e:
+ raise ValueError(f"Unexpected or existing debug_path={debug_path}.") from e
+
+ def wrap_forward(module, full_path):
+ orig_forward = module.forward
+
+ @functools.wraps(orig_forward)
+ def wrapped_forward(*inps, **kws):
+ if _is_rank_zero():
+ dict_inputs = {"args": inps, "kwargs": kws}
+ dict_inputs = {k: dict_inputs[k] for k in dict_inputs if len(dict_inputs[k]) > 0}
+ node = {
+ "module_path": full_path,
+ "inputs": _serialize_io(
+ dict_inputs,
+ debug_path=debug_path,
+ use_repr=use_repr,
+ path_to_value=f"{full_path}_inputs",
+ ),
+ "outputs": None,
+ "children": [],
+ }
+ model._debugger_model_call_stack.append(node)
+ with torch.no_grad():
+ out = orig_forward(*inps, **kws)
+
+ if _is_rank_zero():
+ if sum(1 for _ in module.named_children()) > 0:
+ node["outputs"] = None
+ else:
+ node["outputs"] = _serialize_io(
+ out,
+ debug_path=debug_path,
+ use_repr=use_repr,
+ path_to_value=f"{full_path}_outputs",
+ )
+
+ finished = model._debugger_model_call_stack.pop()
+ # prune empty vertices here as well (mostly empty children nodes)
+ if not finished["children"]:
+ finished.pop("children")
+
+ if model._debugger_model_call_stack:
+ model._debugger_model_call_stack[-1]["children"].append(finished)
+ return out
+
+ module.forward = wrapped_forward
+
+ # wrap all submodules
+ for name, submodule in model.named_modules():
+ if name == "":
+ continue
+ wrap_forward(submodule, f"{class_name}.{name}")
+
+ # wrap top-level forward
+ real_top_forward = model.forward
+
+ @functools.wraps(real_top_forward)
+ def top_wrapped_forward(*inps, **kws):
+ if _is_rank_zero():
+ top_node = {
+ "module_path": f"{class_name} (top-level)",
+ "inputs": _serialize_io(
+ {"args": inps, "kwargs": kws},
+ debug_path=debug_path,
+ use_repr=use_repr,
+ path_to_value=f"{class_name}_inputs",
+ ),
+ "outputs": None,
+ "children": [],
+ }
+ model._debugger_model_call_stack.append(top_node)
+
+ out = real_top_forward(*inps, **kws)
+ if _is_rank_zero() and model._debugger_model_call_stack:
+ top_node["outputs"] = _serialize_io(
+ out,
+ debug_path=debug_path,
+ use_repr=use_repr,
+ path_to_value=f"{class_name}_outputs",
+ )
+ finished = model._debugger_model_call_stack.pop()
+ model._call_tree["inputs"] = finished["inputs"]
+ model._call_tree["outputs"] = finished["outputs"]
+ model._call_tree["children"] = finished["children"]
+ # prune empty stuff for visibility
+ [model._call_tree.pop(k, None) for k in list(model._call_tree.keys()) if not model._call_tree[k]]
+
+ # prune layers that are not 0 or last
+ if do_prune_layers:
+ prune_intermediate_layers(model._call_tree)
+ # Write final JSON trace here
+ log_model_debug_trace(debug_path=debug_path, model=model)
+ return out
+
+ model.forward = top_wrapped_forward
+
+
+@requires(backends=("torch",))
+@contextmanager
+def model_addition_debugger_context(
+ model,
+ debug_path: Optional[str] = None,
+ do_prune_layers: bool = True,
+ use_repr: bool = True,
+):
+ """
+ # Model addition debugger - context manager for model adders
+ This context manager is a power user tool intended for model adders.
+
+ It tracks all forward calls within a model forward and logs a slice of each input and output on a nested JSON file.
+ If `use_repr=True` (the default), the JSON file will record a `repr()`-ized version of the tensors as a list of
+ strings. If `use_repr=False`, the full tensors will be stored in separate SafeTensors files and the JSON file will
+ provide a relative path to that file.
+
+ To note, this context manager enforces `torch.no_grad()`.
+
+ ## Usage
+
+ add the context manager to a model to debug
+
+ ```python
+ import torch
+
+ from PIL import Image
+ from transformers import LlavaProcessor, LlavaForConditionalGeneration, model_addition_debugger_context
+
+ torch.random.manual_seed(673)
+
+ # load pretrained model and processor
+ model_id = "llava-hf/llava-1.5-7b-hf"
+ processor = LlavaProcessor.from_pretrained(model_id)
+ model = LlavaForConditionalGeneration.from_pretrained(model_id)
+
+ # create random image input
+ random_image = Image.fromarray(torch.randint(0, 256, (224, 224, 3), dtype=torch.uint8).numpy())
+
+ # prompt
+ prompt = "Describe this image."
+
+ # process inputs
+ inputs = processor(text=prompt, images=random_image, return_tensors="pt")
+
+ # call forward method (not .generate!)
+ with model_addition_debugger_context(model, debug_path="Your_debug_path", do_prune_layers=False):
+ output = model.forward(**inputs)
+ ```
+
+ """
+ orig_forwards = {m: m.forward for _, m in model.named_modules()}
+ orig_forwards[model] = model.forward
+ _attach_debugger_logic(model, debug_path, do_prune_layers, use_repr)
+ try:
+ yield model
+ finally:
+ for module_instance, forward_method in orig_forwards.items():
+ module_instance.forward = forward_method
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modelcard.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modelcard.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd3a0b4017334752ecec35a634ead9eae96d462b
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modelcard.py
@@ -0,0 +1,914 @@
+# Copyright 2018 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Configuration base class and utilities."""
+
+import copy
+import json
+import os
+import warnings
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Optional, Union
+
+import requests
+import yaml
+from huggingface_hub import model_info
+from huggingface_hub.errors import OfflineModeIsEnabled
+from huggingface_hub.utils import HFValidationError
+
+from . import __version__
+from .models.auto.modeling_auto import (
+ MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
+ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
+ MODEL_FOR_CTC_MAPPING_NAMES,
+ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
+ MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES,
+ MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES,
+ MODEL_FOR_MASKED_LM_MAPPING_NAMES,
+ MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES,
+ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
+ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
+ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
+ MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
+ MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES,
+ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
+ MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
+)
+from .training_args import ParallelMode
+from .utils import (
+ MODEL_CARD_NAME,
+ cached_file,
+ is_datasets_available,
+ is_offline_mode,
+ is_tf_available,
+ is_tokenizers_available,
+ is_torch_available,
+ logging,
+)
+
+
+TASK_MAPPING = {
+ "text-generation": MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
+ "image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
+ "image-segmentation": MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES,
+ "fill-mask": MODEL_FOR_MASKED_LM_MAPPING_NAMES,
+ "object-detection": MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES,
+ "question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
+ "text2text-generation": MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
+ "text-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
+ "table-question-answering": MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES,
+ "token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
+ "audio-classification": MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
+ "automatic-speech-recognition": {**MODEL_FOR_CTC_MAPPING_NAMES, **MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES},
+ "zero-shot-image-classification": MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
+ "image-text-to-text": MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES,
+}
+
+logger = logging.get_logger(__name__)
+
+
+class ModelCard:
+ r"""
+ Structured Model Card class. Store model card as well as methods for loading/downloading/saving model cards.
+
+ Please read the following paper for details and explanation on the sections: "Model Cards for Model Reporting" by
+ Margaret Mitchell, Simone Wu, Andrew Zaldivar, Parker Barnes, Lucy Vasserman, Ben Hutchinson, Elena Spitzer,
+ Inioluwa Deborah Raji and Timnit Gebru for the proposal behind model cards. Link: https://huggingface.co/papers/1810.03993
+
+ Note: A model card can be loaded and saved to disk.
+ """
+
+ def __init__(self, **kwargs):
+ warnings.warn(
+ "The class `ModelCard` is deprecated and will be removed in version 5 of Transformers", FutureWarning
+ )
+ # Recommended attributes from https://huggingface.co/papers/1810.03993 (see papers)
+ self.model_details = kwargs.pop("model_details", {})
+ self.intended_use = kwargs.pop("intended_use", {})
+ self.factors = kwargs.pop("factors", {})
+ self.metrics = kwargs.pop("metrics", {})
+ self.evaluation_data = kwargs.pop("evaluation_data", {})
+ self.training_data = kwargs.pop("training_data", {})
+ self.quantitative_analyses = kwargs.pop("quantitative_analyses", {})
+ self.ethical_considerations = kwargs.pop("ethical_considerations", {})
+ self.caveats_and_recommendations = kwargs.pop("caveats_and_recommendations", {})
+
+ # Open additional attributes
+ for key, value in kwargs.items():
+ try:
+ setattr(self, key, value)
+ except AttributeError as err:
+ logger.error(f"Can't set {key} with value {value} for {self}")
+ raise err
+
+ def save_pretrained(self, save_directory_or_file):
+ """Save a model card object to the directory or file `save_directory_or_file`."""
+ if os.path.isdir(save_directory_or_file):
+ # If we save using the predefined names, we can load using `from_pretrained`
+ output_model_card_file = os.path.join(save_directory_or_file, MODEL_CARD_NAME)
+ else:
+ output_model_card_file = save_directory_or_file
+
+ self.to_json_file(output_model_card_file)
+ logger.info(f"Model card saved in {output_model_card_file}")
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+ r"""
+ Instantiate a [`ModelCard`] from a pre-trained model model card.
+
+ Parameters:
+ pretrained_model_name_or_path: either:
+
+ - a string, the *model id* of a pretrained model card hosted inside a model repo on huggingface.co.
+ - a path to a *directory* containing a model card file saved using the [`~ModelCard.save_pretrained`]
+ method, e.g.: `./my_model_directory/`.
+ - a path or url to a saved model card JSON *file*, e.g.: `./my_model_directory/modelcard.json`.
+
+ cache_dir: (*optional*) string:
+ Path to a directory in which a downloaded pre-trained model card should be cached if the standard cache
+ should not be used.
+
+ kwargs: (*optional*) dict: key/value pairs with which to update the ModelCard object after loading.
+
+ - The values in kwargs of any keys which are model card attributes will be used to override the loaded
+ values.
+ - Behavior concerning key/value pairs whose keys are *not* model card attributes is controlled by the
+ *return_unused_kwargs* keyword parameter.
+
+ proxies: (*optional*) dict, default None:
+ A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
+
+ return_unused_kwargs: (*optional*) bool:
+
+ - If False, then this function returns just the final model card object.
+ - If True, then this functions returns a tuple *(model card, unused_kwargs)* where *unused_kwargs* is a
+ dictionary consisting of the key/value pairs whose keys are not model card attributes: ie the part of
+ kwargs which has not been used to update *ModelCard* and is otherwise ignored.
+
+ Examples:
+
+ ```python
+ # Download model card from huggingface.co and cache.
+ modelcard = ModelCard.from_pretrained("google-bert/bert-base-uncased")
+ # Model card was saved using *save_pretrained('./test/saved_model/')*
+ modelcard = ModelCard.from_pretrained("./test/saved_model/")
+ modelcard = ModelCard.from_pretrained("./test/saved_model/modelcard.json")
+ modelcard = ModelCard.from_pretrained("google-bert/bert-base-uncased", output_attentions=True, foo=False)
+ ```"""
+ cache_dir = kwargs.pop("cache_dir", None)
+ proxies = kwargs.pop("proxies", None)
+ return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
+ from_pipeline = kwargs.pop("_from_pipeline", None)
+
+ user_agent = {"file_type": "model_card"}
+ if from_pipeline is not None:
+ user_agent["using_pipeline"] = from_pipeline
+
+ is_local = os.path.isdir(pretrained_model_name_or_path)
+ if os.path.isfile(pretrained_model_name_or_path):
+ resolved_model_card_file = pretrained_model_name_or_path
+ is_local = True
+ else:
+ try:
+ # Load from URL or cache if already cached
+ resolved_model_card_file = cached_file(
+ pretrained_model_name_or_path,
+ filename=MODEL_CARD_NAME,
+ cache_dir=cache_dir,
+ proxies=proxies,
+ user_agent=user_agent,
+ )
+ if is_local:
+ logger.info(f"loading model card file {resolved_model_card_file}")
+ else:
+ logger.info(f"loading model card file {MODEL_CARD_NAME} from cache at {resolved_model_card_file}")
+ # Load model card
+ modelcard = cls.from_json_file(resolved_model_card_file)
+
+ except (OSError, json.JSONDecodeError):
+ # We fall back on creating an empty model card
+ modelcard = cls()
+
+ # Update model card with kwargs if needed
+ to_remove = []
+ for key, value in kwargs.items():
+ if hasattr(modelcard, key):
+ setattr(modelcard, key, value)
+ to_remove.append(key)
+ for key in to_remove:
+ kwargs.pop(key, None)
+
+ logger.info(f"Model card: {modelcard}")
+ if return_unused_kwargs:
+ return modelcard, kwargs
+ else:
+ return modelcard
+
+ @classmethod
+ def from_dict(cls, json_object):
+ """Constructs a `ModelCard` from a Python dictionary of parameters."""
+ return cls(**json_object)
+
+ @classmethod
+ def from_json_file(cls, json_file):
+ """Constructs a `ModelCard` from a json file of parameters."""
+ with open(json_file, encoding="utf-8") as reader:
+ text = reader.read()
+ dict_obj = json.loads(text)
+ return cls(**dict_obj)
+
+ def __eq__(self, other):
+ return self.__dict__ == other.__dict__
+
+ def __repr__(self):
+ return str(self.to_json_string())
+
+ def to_dict(self):
+ """Serializes this instance to a Python dictionary."""
+ output = copy.deepcopy(self.__dict__)
+ return output
+
+ def to_json_string(self):
+ """Serializes this instance to a JSON string."""
+ return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
+
+ def to_json_file(self, json_file_path):
+ """Save this instance to a json file."""
+ with open(json_file_path, "w", encoding="utf-8") as writer:
+ writer.write(self.to_json_string())
+
+
+AUTOGENERATED_TRAINER_COMMENT = """
+
+"""
+
+AUTOGENERATED_KERAS_COMMENT = """
+
+"""
+
+
+TASK_TAG_TO_NAME_MAPPING = {
+ "fill-mask": "Masked Language Modeling",
+ "image-classification": "Image Classification",
+ "image-segmentation": "Image Segmentation",
+ "multiple-choice": "Multiple Choice",
+ "object-detection": "Object Detection",
+ "question-answering": "Question Answering",
+ "summarization": "Summarization",
+ "table-question-answering": "Table Question Answering",
+ "text-classification": "Text Classification",
+ "text-generation": "Causal Language Modeling",
+ "text2text-generation": "Sequence-to-sequence Language Modeling",
+ "token-classification": "Token Classification",
+ "translation": "Translation",
+ "zero-shot-classification": "Zero Shot Classification",
+ "automatic-speech-recognition": "Automatic Speech Recognition",
+ "audio-classification": "Audio Classification",
+}
+
+
+METRIC_TAGS = [
+ "accuracy",
+ "bleu",
+ "f1",
+ "matthews_correlation",
+ "pearsonr",
+ "precision",
+ "recall",
+ "rouge",
+ "sacrebleu",
+ "spearmanr",
+ "wer",
+]
+
+
+def _listify(obj):
+ if obj is None:
+ return []
+ elif isinstance(obj, str):
+ return [obj]
+ else:
+ return obj
+
+
+def _insert_values_as_list(metadata, name, values):
+ if values is None:
+ return metadata
+ if isinstance(values, str):
+ values = [values]
+ values = [v for v in values if v is not None]
+ if len(values) == 0:
+ return metadata
+ metadata[name] = values
+ return metadata
+
+
+def infer_metric_tags_from_eval_results(eval_results):
+ if eval_results is None:
+ return {}
+ result = {}
+ for key in eval_results:
+ if key.lower().replace(" ", "_") in METRIC_TAGS:
+ result[key.lower().replace(" ", "_")] = key
+ elif key.lower() == "rouge1":
+ result["rouge"] = key
+ return result
+
+
+def _insert_value(metadata, name, value):
+ if value is None:
+ return metadata
+ metadata[name] = value
+ return metadata
+
+
+def is_hf_dataset(dataset):
+ if not is_datasets_available():
+ return False
+
+ from datasets import Dataset, IterableDataset
+
+ return isinstance(dataset, (Dataset, IterableDataset))
+
+
+def _get_mapping_values(mapping):
+ result = []
+ for v in mapping.values():
+ if isinstance(v, (tuple, list)):
+ result += list(v)
+ else:
+ result.append(v)
+ return result
+
+
+@dataclass
+class TrainingSummary:
+ model_name: str
+ language: Optional[Union[str, list[str]]] = None
+ license: Optional[str] = None
+ tags: Optional[Union[str, list[str]]] = None
+ finetuned_from: Optional[str] = None
+ tasks: Optional[Union[str, list[str]]] = None
+ dataset: Optional[Union[str, list[str]]] = None
+ dataset_tags: Optional[Union[str, list[str]]] = None
+ dataset_args: Optional[Union[str, list[str]]] = None
+ dataset_metadata: Optional[dict[str, Any]] = None
+ eval_results: Optional[dict[str, float]] = None
+ eval_lines: Optional[list[str]] = None
+ hyperparameters: Optional[dict[str, Any]] = None
+ source: Optional[str] = "trainer"
+
+ def __post_init__(self):
+ # Infer default license from the checkpoint used, if possible.
+ if (
+ self.license is None
+ and not is_offline_mode()
+ and self.finetuned_from is not None
+ and len(self.finetuned_from) > 0
+ ):
+ try:
+ info = model_info(self.finetuned_from)
+ for tag in info.tags:
+ if tag.startswith("license:"):
+ self.license = tag[8:]
+ except (
+ requests.exceptions.HTTPError,
+ requests.exceptions.ConnectionError,
+ HFValidationError,
+ OfflineModeIsEnabled,
+ ):
+ pass
+
+ def create_model_index(self, metric_mapping):
+ model_index = {"name": self.model_name}
+
+ # Dataset mapping tag -> name
+ dataset_names = _listify(self.dataset)
+ dataset_tags = _listify(self.dataset_tags)
+ dataset_args = _listify(self.dataset_args)
+ dataset_metadata = _listify(self.dataset_metadata)
+ if len(dataset_args) < len(dataset_tags):
+ dataset_args = dataset_args + [None] * (len(dataset_tags) - len(dataset_args))
+ dataset_mapping = dict(zip(dataset_tags, dataset_names))
+ dataset_arg_mapping = dict(zip(dataset_tags, dataset_args))
+ dataset_metadata_mapping = dict(zip(dataset_tags, dataset_metadata))
+
+ task_mapping = {
+ task: TASK_TAG_TO_NAME_MAPPING[task] for task in _listify(self.tasks) if task in TASK_TAG_TO_NAME_MAPPING
+ }
+
+ model_index["results"] = []
+
+ if len(task_mapping) == 0 and len(dataset_mapping) == 0:
+ return [model_index]
+ if len(task_mapping) == 0:
+ task_mapping = {None: None}
+ if len(dataset_mapping) == 0:
+ dataset_mapping = {None: None}
+
+ # One entry per dataset and per task
+ all_possibilities = [(task_tag, ds_tag) for task_tag in task_mapping for ds_tag in dataset_mapping]
+ for task_tag, ds_tag in all_possibilities:
+ result = {}
+ if task_tag is not None:
+ result["task"] = {"name": task_mapping[task_tag], "type": task_tag}
+
+ if ds_tag is not None:
+ metadata = dataset_metadata_mapping.get(ds_tag, {})
+ result["dataset"] = {
+ "name": dataset_mapping[ds_tag],
+ "type": ds_tag,
+ **metadata,
+ }
+ if dataset_arg_mapping[ds_tag] is not None:
+ result["dataset"]["args"] = dataset_arg_mapping[ds_tag]
+
+ if len(metric_mapping) > 0:
+ result["metrics"] = []
+ for metric_tag, metric_name in metric_mapping.items():
+ result["metrics"].append(
+ {
+ "name": metric_name,
+ "type": metric_tag,
+ "value": self.eval_results[metric_name],
+ }
+ )
+
+ # Remove partial results to avoid the model card being rejected.
+ if "task" in result and "dataset" in result and "metrics" in result:
+ model_index["results"].append(result)
+ else:
+ logger.info(f"Dropping the following result as it does not have all the necessary fields:\n{result}")
+
+ return [model_index]
+
+ def create_metadata(self):
+ metric_mapping = infer_metric_tags_from_eval_results(self.eval_results)
+
+ metadata = {}
+ metadata = _insert_value(metadata, "library_name", "transformers")
+ metadata = _insert_values_as_list(metadata, "language", self.language)
+ metadata = _insert_value(metadata, "license", self.license)
+ if self.finetuned_from is not None and isinstance(self.finetuned_from, str) and len(self.finetuned_from) > 0:
+ metadata = _insert_value(metadata, "base_model", self.finetuned_from)
+ metadata = _insert_values_as_list(metadata, "tags", self.tags)
+ metadata = _insert_values_as_list(metadata, "datasets", self.dataset_tags)
+ metadata = _insert_values_as_list(metadata, "metrics", list(metric_mapping.keys()))
+ metadata["model-index"] = self.create_model_index(metric_mapping)
+
+ return metadata
+
+ def to_model_card(self):
+ model_card = ""
+
+ metadata = yaml.dump(self.create_metadata(), sort_keys=False)
+ if len(metadata) > 0:
+ model_card = f"---\n{metadata}---\n"
+
+ # Now the model card for realsies.
+ if self.source == "trainer":
+ model_card += AUTOGENERATED_TRAINER_COMMENT
+ else:
+ model_card += AUTOGENERATED_KERAS_COMMENT
+
+ model_card += f"\n# {self.model_name}\n\n"
+
+ if self.finetuned_from is None:
+ model_card += "This model was trained from scratch on "
+ else:
+ model_card += (
+ "This model is a fine-tuned version of"
+ f" [{self.finetuned_from}](https://huggingface.co/{self.finetuned_from}) on "
+ )
+
+ if self.dataset is None or (isinstance(self.dataset, list) and len(self.dataset) == 0):
+ model_card += "an unknown dataset."
+ else:
+ if isinstance(self.dataset, str):
+ model_card += f"the {self.dataset} dataset."
+ elif isinstance(self.dataset, (tuple, list)) and len(self.dataset) == 1:
+ model_card += f"the {self.dataset[0]} dataset."
+ else:
+ model_card += (
+ ", ".join([f"the {ds}" for ds in self.dataset[:-1]]) + f" and the {self.dataset[-1]} datasets."
+ )
+
+ if self.eval_results is not None:
+ model_card += "\nIt achieves the following results on the evaluation set:\n"
+ model_card += "\n".join([f"- {name}: {_maybe_round(value)}" for name, value in self.eval_results.items()])
+ model_card += "\n"
+
+ model_card += "\n## Model description\n\nMore information needed\n"
+ model_card += "\n## Intended uses & limitations\n\nMore information needed\n"
+ model_card += "\n## Training and evaluation data\n\nMore information needed\n"
+
+ model_card += "\n## Training procedure\n"
+ model_card += "\n### Training hyperparameters\n"
+ if self.hyperparameters is not None:
+ model_card += "\nThe following hyperparameters were used during training:\n"
+ model_card += "\n".join([f"- {name}: {value}" for name, value in self.hyperparameters.items()])
+ model_card += "\n"
+ else:
+ model_card += "\nMore information needed\n"
+
+ if self.eval_lines is not None:
+ model_card += "\n### Training results\n\n"
+ model_card += make_markdown_table(self.eval_lines)
+ model_card += "\n"
+
+ model_card += "\n### Framework versions\n\n"
+ model_card += f"- Transformers {__version__}\n"
+
+ if self.source == "trainer" and is_torch_available():
+ import torch
+
+ model_card += f"- Pytorch {torch.__version__}\n"
+ elif self.source == "keras" and is_tf_available():
+ import tensorflow as tf
+
+ model_card += f"- TensorFlow {tf.__version__}\n"
+ if is_datasets_available():
+ import datasets
+
+ model_card += f"- Datasets {datasets.__version__}\n"
+ if is_tokenizers_available():
+ import tokenizers
+
+ model_card += f"- Tokenizers {tokenizers.__version__}\n"
+
+ return model_card
+
+ @classmethod
+ def from_trainer(
+ cls,
+ trainer,
+ language=None,
+ license=None,
+ tags=None,
+ model_name=None,
+ finetuned_from=None,
+ tasks=None,
+ dataset_tags=None,
+ dataset_metadata=None,
+ dataset=None,
+ dataset_args=None,
+ ):
+ # Infer default from dataset
+ one_dataset = trainer.eval_dataset if trainer.eval_dataset is not None else trainer.train_dataset
+ if is_hf_dataset(one_dataset) and (dataset_tags is None or dataset_args is None or dataset_metadata is None):
+ default_tag = one_dataset.builder_name
+ # Those are not real datasets from the Hub so we exclude them.
+ if default_tag not in ["csv", "json", "pandas", "parquet", "text"]:
+ if dataset_metadata is None:
+ dataset_metadata = [{"config": one_dataset.config_name, "split": str(one_dataset.split)}]
+ if dataset_tags is None:
+ dataset_tags = [default_tag]
+ if dataset_args is None:
+ dataset_args = [one_dataset.config_name]
+
+ if dataset is None and dataset_tags is not None:
+ dataset = dataset_tags
+
+ # Infer default finetuned_from
+ if (
+ finetuned_from is None
+ and hasattr(trainer.model.config, "_name_or_path")
+ and not os.path.isdir(trainer.model.config._name_or_path)
+ ):
+ finetuned_from = trainer.model.config._name_or_path
+
+ # Infer default task tag:
+ if tasks is None:
+ model_class_name = trainer.model.__class__.__name__
+ for task, mapping in TASK_MAPPING.items():
+ if model_class_name in _get_mapping_values(mapping):
+ tasks = task
+
+ if model_name is None:
+ model_name = Path(trainer.args.output_dir).name
+ if len(model_name) == 0:
+ model_name = finetuned_from
+
+ # Add `generated_from_trainer` to the tags
+ if tags is None:
+ tags = ["generated_from_trainer"]
+ elif isinstance(tags, str) and tags != "generated_from_trainer":
+ tags = [tags, "generated_from_trainer"]
+ elif "generated_from_trainer" not in tags:
+ tags.append("generated_from_trainer")
+
+ _, eval_lines, eval_results = parse_log_history(trainer.state.log_history)
+ hyperparameters = extract_hyperparameters_from_trainer(trainer)
+
+ return cls(
+ language=language,
+ license=license,
+ tags=tags,
+ model_name=model_name,
+ finetuned_from=finetuned_from,
+ tasks=tasks,
+ dataset=dataset,
+ dataset_tags=dataset_tags,
+ dataset_args=dataset_args,
+ dataset_metadata=dataset_metadata,
+ eval_results=eval_results,
+ eval_lines=eval_lines,
+ hyperparameters=hyperparameters,
+ )
+
+ @classmethod
+ def from_keras(
+ cls,
+ model,
+ model_name,
+ keras_history=None,
+ language=None,
+ license=None,
+ tags=None,
+ finetuned_from=None,
+ tasks=None,
+ dataset_tags=None,
+ dataset=None,
+ dataset_args=None,
+ ):
+ # Infer default from dataset
+ if dataset is not None:
+ if is_hf_dataset(dataset) and (dataset_tags is None or dataset_args is None):
+ default_tag = dataset.builder_name
+ # Those are not real datasets from the Hub so we exclude them.
+ if default_tag not in ["csv", "json", "pandas", "parquet", "text"]:
+ if dataset_tags is None:
+ dataset_tags = [default_tag]
+ if dataset_args is None:
+ dataset_args = [dataset.config_name]
+
+ if dataset is None and dataset_tags is not None:
+ dataset = dataset_tags
+
+ # Infer default finetuned_from
+ if (
+ finetuned_from is None
+ and hasattr(model.config, "_name_or_path")
+ and not os.path.isdir(model.config._name_or_path)
+ ):
+ finetuned_from = model.config._name_or_path
+
+ # Infer default task tag:
+ if tasks is None:
+ model_class_name = model.__class__.__name__
+ for task, mapping in TASK_MAPPING.items():
+ if model_class_name in _get_mapping_values(mapping):
+ tasks = task
+
+ # Add `generated_from_keras_callback` to the tags
+ if tags is None:
+ tags = ["generated_from_keras_callback"]
+ elif isinstance(tags, str) and tags != "generated_from_keras_callback":
+ tags = [tags, "generated_from_keras_callback"]
+ elif "generated_from_keras_callback" not in tags:
+ tags.append("generated_from_keras_callback")
+
+ if keras_history is not None:
+ _, eval_lines, eval_results = parse_keras_history(keras_history)
+ else:
+ eval_lines = []
+ eval_results = {}
+ hyperparameters = extract_hyperparameters_from_keras(model)
+
+ return cls(
+ language=language,
+ license=license,
+ tags=tags,
+ model_name=model_name,
+ finetuned_from=finetuned_from,
+ tasks=tasks,
+ dataset_tags=dataset_tags,
+ dataset=dataset,
+ dataset_args=dataset_args,
+ eval_results=eval_results,
+ eval_lines=eval_lines,
+ hyperparameters=hyperparameters,
+ source="keras",
+ )
+
+
+def parse_keras_history(logs):
+ """
+ Parse the `logs` of either a `keras.History` object returned by `model.fit()` or an accumulated logs `dict`
+ passed to the `PushToHubCallback`. Returns lines and logs compatible with those returned by `parse_log_history`.
+ """
+ if hasattr(logs, "history"):
+ # This looks like a `History` object
+ if not hasattr(logs, "epoch"):
+ # This history looks empty, return empty results
+ return None, [], {}
+ logs.history["epoch"] = logs.epoch
+ logs = logs.history
+ else:
+ # Training logs is a list of dicts, let's invert it to a dict of lists to match a History object
+ logs = {log_key: [single_dict[log_key] for single_dict in logs] for log_key in logs[0]}
+
+ lines = []
+ for i in range(len(logs["epoch"])):
+ epoch_dict = {log_key: log_value_list[i] for log_key, log_value_list in logs.items()}
+ values = {}
+ for k, v in epoch_dict.items():
+ if k.startswith("val_"):
+ k = "validation_" + k[4:]
+ elif k != "epoch":
+ k = "train_" + k
+ splits = k.split("_")
+ name = " ".join([part.capitalize() for part in splits])
+ values[name] = v
+ lines.append(values)
+
+ eval_results = lines[-1]
+
+ return logs, lines, eval_results
+
+
+def parse_log_history(log_history):
+ """
+ Parse the `log_history` of a Trainer to get the intermediate and final evaluation results.
+ """
+ idx = 0
+ while idx < len(log_history) and "train_runtime" not in log_history[idx]:
+ idx += 1
+
+ # If there are no training logs
+ if idx == len(log_history):
+ idx -= 1
+ while idx >= 0 and "eval_loss" not in log_history[idx]:
+ idx -= 1
+
+ if idx >= 0:
+ return None, None, log_history[idx]
+ else:
+ return None, None, None
+
+ # From now one we can assume we have training logs:
+ train_log = log_history[idx]
+ lines = []
+ training_loss = "No log"
+ for i in range(idx):
+ if "loss" in log_history[i]:
+ training_loss = log_history[i]["loss"]
+ if "eval_loss" in log_history[i]:
+ metrics = log_history[i].copy()
+ _ = metrics.pop("total_flos", None)
+ epoch = metrics.pop("epoch", None)
+ step = metrics.pop("step", None)
+ _ = metrics.pop("eval_runtime", None)
+ _ = metrics.pop("eval_samples_per_second", None)
+ _ = metrics.pop("eval_steps_per_second", None)
+ _ = metrics.pop("eval_jit_compilation_time", None)
+ values = {"Training Loss": training_loss, "Epoch": epoch, "Step": step}
+ for k, v in metrics.items():
+ if k == "eval_loss":
+ values["Validation Loss"] = v
+ else:
+ splits = k.split("_")
+ name = " ".join([part.capitalize() for part in splits[1:]])
+ values[name] = v
+ lines.append(values)
+
+ idx = len(log_history) - 1
+ while idx >= 0 and "eval_loss" not in log_history[idx]:
+ idx -= 1
+
+ if idx > 0:
+ eval_results = {}
+ for key, value in log_history[idx].items():
+ key = key.removeprefix("eval_")
+ if key not in ["runtime", "samples_per_second", "steps_per_second", "epoch", "step"]:
+ camel_cased_key = " ".join([part.capitalize() for part in key.split("_")])
+ eval_results[camel_cased_key] = value
+ return train_log, lines, eval_results
+ else:
+ return train_log, lines, None
+
+
+def extract_hyperparameters_from_keras(model):
+ from .modeling_tf_utils import keras
+
+ hyperparameters = {}
+ if hasattr(model, "optimizer") and model.optimizer is not None:
+ hyperparameters["optimizer"] = model.optimizer.get_config()
+ else:
+ hyperparameters["optimizer"] = None
+ hyperparameters["training_precision"] = keras.mixed_precision.global_policy().name
+
+ return hyperparameters
+
+
+def _maybe_round(v, decimals=4):
+ if isinstance(v, float) and len(str(v).split(".")) > 1 and len(str(v).split(".")[1]) > decimals:
+ return f"{v:.{decimals}f}"
+ return str(v)
+
+
+def _regular_table_line(values, col_widths):
+ values_with_space = [f"| {v}" + " " * (w - len(v) + 1) for v, w in zip(values, col_widths)]
+ return "".join(values_with_space) + "|\n"
+
+
+def _second_table_line(col_widths):
+ values = ["|:" + "-" * w + ":" for w in col_widths]
+ return "".join(values) + "|\n"
+
+
+def make_markdown_table(lines):
+ """
+ Create a nice Markdown table from the results in `lines`.
+ """
+ if lines is None or len(lines) == 0:
+ return ""
+ col_widths = {key: len(str(key)) for key in lines[0]}
+ for line in lines:
+ for key, value in line.items():
+ if col_widths[key] < len(_maybe_round(value)):
+ col_widths[key] = len(_maybe_round(value))
+
+ table = _regular_table_line(list(lines[0].keys()), list(col_widths.values()))
+ table += _second_table_line(list(col_widths.values()))
+ for line in lines:
+ table += _regular_table_line([_maybe_round(v) for v in line.values()], list(col_widths.values()))
+ return table
+
+
+_TRAINING_ARGS_KEYS = [
+ "learning_rate",
+ "train_batch_size",
+ "eval_batch_size",
+ "seed",
+]
+
+
+def extract_hyperparameters_from_trainer(trainer):
+ hyperparameters = {k: getattr(trainer.args, k) for k in _TRAINING_ARGS_KEYS}
+
+ if trainer.args.parallel_mode not in [ParallelMode.NOT_PARALLEL, ParallelMode.NOT_DISTRIBUTED]:
+ hyperparameters["distributed_type"] = (
+ "multi-GPU" if trainer.args.parallel_mode == ParallelMode.DISTRIBUTED else trainer.args.parallel_mode.value
+ )
+ if trainer.args.world_size > 1:
+ hyperparameters["num_devices"] = trainer.args.world_size
+ if trainer.args.gradient_accumulation_steps > 1:
+ hyperparameters["gradient_accumulation_steps"] = trainer.args.gradient_accumulation_steps
+
+ total_train_batch_size = (
+ trainer.args.train_batch_size * trainer.args.world_size * trainer.args.gradient_accumulation_steps
+ )
+ if total_train_batch_size != hyperparameters["train_batch_size"]:
+ hyperparameters["total_train_batch_size"] = total_train_batch_size
+ total_eval_batch_size = trainer.args.eval_batch_size * trainer.args.world_size
+ if total_eval_batch_size != hyperparameters["eval_batch_size"]:
+ hyperparameters["total_eval_batch_size"] = total_eval_batch_size
+
+ if trainer.args.optim:
+ optimizer_name = trainer.args.optim
+ optimizer_args = trainer.args.optim_args if trainer.args.optim_args else "No additional optimizer arguments"
+
+ if "adam" in optimizer_name.lower():
+ hyperparameters["optimizer"] = (
+ f"Use {optimizer_name} with betas=({trainer.args.adam_beta1},{trainer.args.adam_beta2}) and"
+ f" epsilon={trainer.args.adam_epsilon} and optimizer_args={optimizer_args}"
+ )
+ else:
+ hyperparameters["optimizer"] = f"Use {optimizer_name} and the args are:\n{optimizer_args}"
+
+ hyperparameters["lr_scheduler_type"] = trainer.args.lr_scheduler_type.value
+ if trainer.args.warmup_ratio != 0.0:
+ hyperparameters["lr_scheduler_warmup_ratio"] = trainer.args.warmup_ratio
+ if trainer.args.warmup_steps != 0.0:
+ hyperparameters["lr_scheduler_warmup_steps"] = trainer.args.warmup_steps
+ if trainer.args.max_steps != -1:
+ hyperparameters["training_steps"] = trainer.args.max_steps
+ else:
+ hyperparameters["num_epochs"] = trainer.args.num_train_epochs
+
+ if trainer.args.fp16:
+ if trainer.use_apex:
+ hyperparameters["mixed_precision_training"] = f"Apex, opt level {trainer.args.fp16_opt_level}"
+ else:
+ hyperparameters["mixed_precision_training"] = "Native AMP"
+
+ if trainer.args.label_smoothing_factor != 0.0:
+ hyperparameters["label_smoothing_factor"] = trainer.args.label_smoothing_factor
+
+ return hyperparameters
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_attn_mask_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_attn_mask_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0be1b3ed25531a6bd67aa143b0883815b358678d
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_attn_mask_utils.py
@@ -0,0 +1,487 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+IMPORTANT NOTICE: Every class and function in this file is deprecated in favor of using the much more general
+`masking_utils.py` primitives. New code should not rely on it, it is only kept for backward compatibility for now,
+and will be removed in the future.
+"""
+
+from dataclasses import dataclass
+from typing import Optional, Union
+
+import torch
+
+from .utils.import_utils import is_torchdynamo_compiling
+
+
+@dataclass
+class AttentionMaskConverter:
+ """
+ A utility attention mask class that allows one to:
+ - Create a causal 4d mask
+ - Create a causal 4d mask with slided window
+ - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length,
+ key_value_length) that can be multiplied with attention scores
+
+ Examples:
+
+ ```python
+ >>> import torch
+ >>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+
+ >>> converter = AttentionMaskConverter(True)
+ >>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, key_value_length=5, dtype=torch.float32)
+ tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
+ [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
+ [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
+ [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, -3.4028e+38],
+ [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, 0.0000e+00]]]])
+ ```
+
+ Parameters:
+ is_causal (`bool`):
+ Whether the attention mask should be a uni-directional (causal) or bi-directional mask.
+
+ sliding_window (`int`, *optional*):
+ Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer.
+ """
+
+ is_causal: bool
+ sliding_window: int
+
+ def __init__(self, is_causal: bool, sliding_window: Optional[int] = None):
+ self.is_causal = is_causal
+ self.sliding_window = sliding_window
+
+ if self.sliding_window is not None and self.sliding_window <= 0:
+ raise ValueError(
+ f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`"
+ )
+
+ def to_causal_4d(
+ self,
+ batch_size: int,
+ query_length: int,
+ key_value_length: int,
+ dtype: torch.dtype,
+ device: Union[torch.device, "str"] = "cpu",
+ ) -> Optional[torch.Tensor]:
+ """
+ Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative
+ bias to upper right hand triangular matrix (causal mask).
+ """
+ if not self.is_causal:
+ raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.")
+
+ # If shape is not cached, create a new causal mask and cache it
+ input_shape = (batch_size, query_length)
+ past_key_values_length = key_value_length - query_length
+
+ # create causal mask
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ causal_4d_mask = None
+ if input_shape[-1] > 1 or self.sliding_window is not None:
+ causal_4d_mask = self._make_causal_mask(
+ input_shape,
+ dtype,
+ device=device,
+ past_key_values_length=past_key_values_length,
+ sliding_window=self.sliding_window,
+ )
+
+ return causal_4d_mask
+
+ def to_4d(
+ self,
+ attention_mask_2d: torch.Tensor,
+ query_length: int,
+ dtype: torch.dtype,
+ key_value_length: Optional[int] = None,
+ ) -> torch.Tensor:
+ """
+ Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
+ key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
+ causal, a causal mask will be added.
+ """
+ input_shape = (attention_mask_2d.shape[0], query_length)
+
+ # create causal mask
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ causal_4d_mask = None
+ if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
+ if key_value_length is None:
+ raise ValueError(
+ "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
+ )
+
+ past_key_values_length = key_value_length - query_length
+ causal_4d_mask = self._make_causal_mask(
+ input_shape,
+ dtype,
+ device=attention_mask_2d.device,
+ past_key_values_length=past_key_values_length,
+ sliding_window=self.sliding_window,
+ )
+ elif self.sliding_window is not None:
+ raise NotImplementedError("Sliding window is currently only implemented for causal masking")
+
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to(
+ attention_mask_2d.device
+ )
+
+ if causal_4d_mask is not None:
+ expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min)
+
+ # expanded_attn_mask + causal_4d_mask can cause some overflow
+ expanded_4d_mask = expanded_attn_mask
+
+ return expanded_4d_mask
+
+ @staticmethod
+ def _make_causal_mask(
+ input_ids_shape: torch.Size,
+ dtype: torch.dtype,
+ device: torch.device,
+ past_key_values_length: int = 0,
+ sliding_window: Optional[int] = None,
+ ):
+ """
+ Make causal mask used for bi-directional self-attention.
+ """
+ bsz, tgt_len = input_ids_shape
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
+ mask_cond = torch.arange(mask.size(-1), device=device)
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
+
+ mask = mask.to(dtype)
+
+ if past_key_values_length > 0:
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
+
+ # add lower triangular sliding window mask if necessary
+ if sliding_window is not None:
+ diagonal = past_key_values_length - sliding_window - 1
+
+ context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal)
+ # Recent changes in PyTorch prevent mutations on tensors converted with aten::_to_copy
+ # See https://github.com/pytorch/pytorch/issues/127571
+ if is_torchdynamo_compiling():
+ mask = mask.clone()
+ mask.masked_fill_(context_mask, torch.finfo(dtype).min)
+
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
+
+ @staticmethod
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+ """
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+ """
+ bsz, src_len = mask.size()
+ tgt_len = tgt_len if tgt_len is not None else src_len
+
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+
+ inverted_mask = torch.tensor(1.0, dtype=dtype) - expanded_mask
+
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
+
+ @staticmethod
+ def _unmask_unattended(
+ expanded_mask: torch.FloatTensor,
+ min_dtype: float,
+ ):
+ # fmt: off
+ """
+ Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when
+ using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+ Details: https://github.com/pytorch/pytorch/issues/110213
+
+ `expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len].
+ `attention_mask` is [bsz, src_seq_len].
+
+ The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case of alibi attention bias.
+
+ For example, if `expanded_mask` is (e.g. here left-padding case)
+ ```
+ [[[[0, 0, 0],
+ [0, 0, 0],
+ [0, 0, 1]]],
+ [[[1, 0, 0],
+ [1, 1, 0],
+ [1, 1, 1]]],
+ [[[0, 0, 0],
+ [0, 1, 0],
+ [0, 1, 1]]]]
+ ```
+ then the modified `expanded_mask` will be
+ ```
+ [[[[1, 1, 1], <-- modified
+ [1, 1, 1], <-- modified
+ [0, 0, 1]]],
+ [[[1, 0, 0],
+ [1, 1, 0],
+ [1, 1, 1]]],
+ [[[1, 1, 1], <-- modified
+ [0, 1, 0],
+ [0, 1, 1]]]]
+ ```
+ """
+ # fmt: on
+ if expanded_mask.dtype == torch.bool:
+ raise ValueError(
+ "AttentionMaskConverter._unmask_unattended expects a float `expanded_mask`, got a BoolTensor."
+ )
+
+ return expanded_mask.mul(~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True))
+
+ @staticmethod
+ def _ignore_causal_mask_sdpa(
+ attention_mask: Optional[torch.Tensor],
+ inputs_embeds: torch.Tensor,
+ past_key_values_length: int,
+ sliding_window: Optional[int] = None,
+ is_training: bool = False,
+ ) -> bool:
+ """
+ Detects whether the optional user-specified attention_mask & the automatically created causal mask can be
+ ignored in case PyTorch's SDPA is used, rather relying on SDPA's `is_causal` argument.
+
+ In case no token is masked in the `attention_mask` argument, if `query_length == 1` or
+ `key_value_length == query_length`, we rather rely on SDPA `is_causal` argument to use causal/non-causal masks,
+ allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is
+ passed).
+ """
+
+ _, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1]
+ key_value_length = query_length + past_key_values_length
+
+ is_tracing = torch.jit.is_tracing() or isinstance(inputs_embeds, torch.fx.Proxy) or is_torchdynamo_compiling()
+
+ ignore_causal_mask = False
+
+ if attention_mask is None:
+ # TODO: When tracing with TorchDynamo with fullgraph=True, the model is recompiled depending on the input
+ # shape, thus SDPA's `is_causal` argument is rightfully updated
+ # (see https://gist.github.com/fxmarty/1313f39037fc1c112508989628c57363). However, when using
+ # `torch.export` or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is
+ # hard-coded. If a user exports a model with q_len > 1, the exported model will hard-code `is_causal=True`
+ # which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108).
+ # Thus, we only set `ignore_causal_mask = True` if the model is set to training.
+ #
+ # Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal`
+ # ("TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor").
+ if (
+ (is_training or not is_tracing)
+ and (query_length == 1 or key_value_length == query_length)
+ and (sliding_window is None or key_value_length < sliding_window)
+ ):
+ ignore_causal_mask = True
+ elif sliding_window is None or key_value_length < sliding_window:
+ if len(attention_mask.shape) == 4:
+ return False
+ elif not is_tracing and torch.all(attention_mask == 1):
+ if query_length == 1 or key_value_length == query_length:
+ # For query_length == 1, causal attention and bi-directional attention are the same.
+ ignore_causal_mask = True
+
+ # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore
+ # the attention mask, as SDPA causal mask generation may be wrong. We will set `is_causal=False` in
+ # SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
+ # Reference: https://github.com/pytorch/pytorch/issues/108108
+ # TODO: maybe revisit this with https://github.com/pytorch/pytorch/pull/114823 in PyTorch 2.3.
+
+ return ignore_causal_mask
+
+
+def _prepare_4d_causal_attention_mask(
+ attention_mask: Optional[torch.Tensor],
+ input_shape: Union[torch.Size, tuple, list],
+ inputs_embeds: torch.Tensor,
+ past_key_values_length: int,
+ sliding_window: Optional[int] = None,
+):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`
+
+ Args:
+ attention_mask (`torch.Tensor` or `None`):
+ A 2D attention mask of shape `(batch_size, key_value_length)`
+ input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
+ The input shape should be a tuple that defines `(batch_size, query_length)`.
+ inputs_embeds (`torch.Tensor`):
+ The embedded inputs as a torch Tensor.
+ past_key_values_length (`int`):
+ The length of the key value cache.
+ sliding_window (`int`, *optional*):
+ If the model uses windowed attention, a sliding window should be passed.
+ """
+ attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
+
+ key_value_length = input_shape[-1] + past_key_values_length
+
+ # 4d mask is passed through the layers
+ if attention_mask is not None and len(attention_mask.shape) == 2:
+ attention_mask = attn_mask_converter.to_4d(
+ attention_mask, input_shape[-1], key_value_length=key_value_length, dtype=inputs_embeds.dtype
+ )
+ elif attention_mask is not None and len(attention_mask.shape) == 4:
+ expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
+ if tuple(attention_mask.shape) != expected_shape:
+ raise ValueError(
+ f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
+ )
+ else:
+ # if the 4D mask has correct shape - invert it and fill with negative infinity
+ inverted_mask = 1.0 - attention_mask
+ attention_mask = inverted_mask.masked_fill(
+ inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
+ )
+ else:
+ attention_mask = attn_mask_converter.to_causal_4d(
+ input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
+ )
+
+ return attention_mask
+
+
+# Adapted from _prepare_4d_causal_attention_mask
+def _prepare_4d_causal_attention_mask_for_sdpa(
+ attention_mask: Optional[torch.Tensor],
+ input_shape: Union[torch.Size, tuple, list],
+ inputs_embeds: torch.Tensor,
+ past_key_values_length: int,
+ sliding_window: Optional[int] = None,
+):
+ """
+ Prepares the correct `attn_mask` argument to be used by `torch.nn.functional.scaled_dot_product_attention`.
+
+ In case no token is masked in the `attention_mask` argument, we simply set it to `None` for the cases `query_length == 1` and
+ `key_value_length == query_length`, and rely instead on SDPA `is_causal` argument to use causal/non-causal masks,
+ allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed).
+ """
+ attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
+
+ key_value_length = input_shape[-1] + past_key_values_length
+
+ # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
+ # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
+ # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
+ is_tracing = torch.jit.is_tracing() or isinstance(inputs_embeds, torch.fx.Proxy) or is_torchdynamo_compiling()
+
+ ignore_causal_mask = AttentionMaskConverter._ignore_causal_mask_sdpa(
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ past_key_values_length=past_key_values_length,
+ sliding_window=sliding_window,
+ )
+
+ if ignore_causal_mask:
+ expanded_4d_mask = None
+ elif attention_mask is None:
+ expanded_4d_mask = attn_mask_converter.to_causal_4d(
+ input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
+ )
+ else:
+ if attention_mask.dim() == 4:
+ expanded_4d_mask = attention_mask
+ else:
+ expanded_4d_mask = attn_mask_converter.to_4d(
+ attention_mask,
+ input_shape[-1],
+ dtype=inputs_embeds.dtype,
+ key_value_length=key_value_length,
+ )
+
+ # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+ # Details: https://github.com/pytorch/pytorch/issues/110213
+ if not is_tracing and expanded_4d_mask.device.type == "cuda":
+ expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
+ expanded_4d_mask, min_dtype=torch.finfo(inputs_embeds.dtype).min
+ )
+
+ return expanded_4d_mask
+
+
+def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+ """
+ Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`
+
+ Args:
+ mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)`
+ dtype (`torch.dtype`):
+ The torch dtype the created mask shall have.
+ tgt_len (`int`):
+ The target length or query length the created mask shall have.
+ """
+ return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
+
+
+def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+ """
+ Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`
+
+ Args:
+ mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)`
+ dtype (`torch.dtype`):
+ The torch dtype the created mask shall have.
+ tgt_len (`int`):
+ The target length or query length the created mask shall have.
+ """
+ _, key_value_length = mask.shape
+ tgt_len = tgt_len if tgt_len is not None else key_value_length
+
+ is_tracing = torch.jit.is_tracing() or isinstance(mask, torch.fx.Proxy) or is_torchdynamo_compiling()
+
+ # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture data-dependent controlflows.
+ if not is_tracing and torch.all(mask == 1):
+ return None
+ else:
+ return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
+
+
+def _create_4d_causal_attention_mask(
+ input_shape: Union[torch.Size, tuple, list],
+ dtype: torch.dtype,
+ device: torch.device,
+ past_key_values_length: int = 0,
+ sliding_window: Optional[int] = None,
+) -> Optional[torch.Tensor]:
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)`
+
+ Args:
+ input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
+ The input shape should be a tuple that defines `(batch_size, query_length)`.
+ dtype (`torch.dtype`):
+ The torch dtype the created mask shall have.
+ device (`int`):
+ The torch device the created mask shall have.
+ sliding_window (`int`, *optional*):
+ If the model uses windowed attention, a sliding window should be passed.
+ """
+ attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
+
+ key_value_length = past_key_values_length + input_shape[-1]
+ attention_mask = attn_mask_converter.to_causal_4d(
+ input_shape[0], input_shape[-1], key_value_length, dtype=dtype, device=device
+ )
+
+ return attention_mask
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_flash_attention_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_flash_attention_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5312b0dd9cd0b40f6c7b64356de203f579e8b263
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_flash_attention_utils.py
@@ -0,0 +1,668 @@
+# Copyright 2025 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import inspect
+import os
+from functools import partial
+from typing import Optional, TypedDict
+
+import torch
+import torch.nn.functional as F
+
+from .utils import (
+ is_flash_attn_2_available,
+ is_flash_attn_3_available,
+ is_flash_attn_greater_or_equal_2_10,
+ is_torch_npu_available,
+ logging,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+# TODO Deprecate when all models have the attention interface
+def flash_attn_supports_top_left_mask():
+ if is_flash_attn_3_available():
+ return False
+ if is_flash_attn_2_available():
+ return not is_flash_attn_greater_or_equal_2_10()
+
+ from .integrations.npu_flash_attention import is_npu_fa2_top_left_aligned_causal_mask
+
+ return is_npu_fa2_top_left_aligned_causal_mask()
+
+
+# TODO Deprecate when all models have the attention interface
+def is_flash_attn_available():
+ return is_flash_attn_3_available() or is_flash_attn_2_available() or is_torch_npu_available()
+
+
+# `globals()` is not compatible with dynamo, hence we have do define them in global scope ourselves
+_flash_fn = None
+_flash_varlen_fn = None
+_pad_fn = None
+_unpad_fn = None
+
+# function that processes kwargs, generalized to handle any supported kwarg within the function
+_process_flash_kwargs_fn = None
+# exceptions where hf API doesn't match the original flash attention API
+_hf_api_to_flash_mapping = {
+ "dropout": "dropout_p",
+ "sliding_window": "window_size",
+}
+
+
+def _lazy_imports(implementation: Optional[str]):
+ """
+ Lazy loads the respective flash attention implementations.
+
+ Return:
+ flash_attn_func: The base flash attention function.
+ flash_attn_varlen_func: The flash attention function supporting variable sequence lengths,
+ e.g. for padding-free training.
+ pad_input: The function to pad inputs into one sequence and returning the respective kwargs.
+ unpad_input: The function to unpad outputs based on the kwargs (from pad_input).
+ """
+ is_fa2 = is_flash_attn_2_available()
+ is_fa3 = is_flash_attn_3_available()
+
+ pad_input, unpad_input = _pad_input, _unpad_input
+
+ if (implementation == "flash_attention_2" and is_fa2) or (implementation is None and is_fa2 and not is_fa3):
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
+ from flash_attn.bert_padding import pad_input, unpad_input
+ elif is_torch_npu_available():
+ # Package `flash-attn` is unavailable on Ascend NPU, which will cause ImportError
+ # Flash-Attention2 related apis for Ascend NPU must be imported from `.integrations.npu_flash_attention` module
+ from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_func
+ from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func
+ else:
+ if implementation == "flash_attention_3" or (implementation is None and is_fa3):
+ from flash_attn_interface import flash_attn_func, flash_attn_varlen_func
+ # Kernels fallback
+ else:
+ flash_attn_func = getattr(implementation, "flash_attn_func", None)
+ flash_attn_varlen_func = getattr(implementation, "flash_attn_varlen_func", None)
+ if flash_attn_varlen_func is None or flash_attn_func is None:
+ raise ValueError(
+ f"Could not find the currently requested flash attention implementation at `{implementation}`."
+ f"Make sure that you request a valid kernel from the hub, e.g. `kernels-community/flash-attn`."
+ )
+
+ return flash_attn_func, flash_attn_varlen_func, pad_input, unpad_input
+
+
+def _lazy_define_process_function(flash_function):
+ """
+ Depending on the version and kernel some features are not supported. Due to limitations in
+ `torch.compile`, we opt to statically type which (optional) kwarg parameters are supported
+ within `_process_flash_attention_kwargs`.
+
+ NOTE: While all supported kwargs are marked as `True`, everything else is marked as `False`.
+ This might be confusing for kwargs that we use in any case, e.g. `is_causal`.
+ """
+
+ flash_parameters = inspect.signature(flash_function).parameters
+ process_parameters = inspect.signature(_process_flash_attention_kwargs).parameters
+
+ supports_mapping = {}
+ for param in process_parameters:
+ fa_param = _hf_api_to_flash_mapping.get(param, param)
+ supports_mapping[fa_param] = fa_param in flash_parameters
+
+ return partial(_process_flash_attention_kwargs, supports_mapping=supports_mapping)
+
+
+def lazy_import_flash_attention(implementation: Optional[str], force_import: Optional[bool] = False):
+ """
+ Lazily import flash attention and return the respective functions + flags.
+
+ NOTE: For fullgraph, this needs to be called before compile, while no fullgraph can
+ work without preloading. See `load_and_register_kernel` in `integrations.hub_kernels`.
+ """
+ global _flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn
+ if force_import or any(k is None for k in [_flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn]):
+ _flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn = _lazy_imports(implementation)
+
+ global _process_flash_kwargs_fn
+ if force_import or _process_flash_kwargs_fn is None:
+ _process_flash_kwargs_fn = _lazy_define_process_function(_flash_varlen_fn)
+
+ return (_flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn), _process_flash_kwargs_fn
+
+
+def _index_first_axis(tensor, indices):
+ """
+ A local implementation of the PyTorch indexing operation `tensor[indices]` on the first axis,
+ after flattening the first two dimensions of the tensor. This is functionally equivalent to
+ FA2's `index_first_axis` and replaces the need to import it.
+ """
+ # The input tensor is expected to be of shape (batch, seq_len, ...). We flatten the first
+ # two dimensions to get (total_tokens, ...) before indexing.
+ reshaped_tensor = tensor.reshape(-1, *tensor.shape[2:])
+ return reshaped_tensor[indices]
+
+
+def _unpad_input(hidden_states, attention_mask, unused_mask=None):
+ """
+ unpad_input function for flash attention variants that do not have them within their pkg themselves, e.g. fa3.
+
+ Arguments:
+ hidden_states: (batch, seqlen, ...)
+ attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
+ unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
+
+ Return:
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
+ indices: (total_nnz), the indices of masked tokens from the flattened input sequence.
+ cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
+ max_seqlen_in_batch: int
+ seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.
+ """
+ all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
+ seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
+ used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+ indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+
+ return (
+ _index_first_axis(hidden_states, indices),
+ indices,
+ cu_seqlens,
+ max_seqlen_in_batch,
+ used_seqlens_in_batch,
+ )
+
+
+def _pad_input(hidden_states, indices, batch, seqlen):
+ """
+ pad_input function for flash attention variants that do not have them within their pkg themselves, e.g. fa3.
+
+ Arguments:
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
+ indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
+ batch: int, batch size for the padded sequence.
+ seqlen: int, maximum sequence length for the padded sequence.
+
+ Return:
+ hidden_states: (batch, seqlen, ...)
+ """
+ dim = hidden_states.shape[1:]
+ output = torch.zeros((batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype)
+ output[indices] = hidden_states
+ return output.view(batch, seqlen, *dim)
+
+
+def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]:
+ """
+ Retrieves indexing data required to repad unpadded (ragged) tensors.
+
+ Arguments:
+ attention_mask (`torch.Tensor`):
+ Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
+
+ Return:
+ indices (`torch.Tensor`):
+ The indices of non-masked tokens from the flattened input sequence.
+ cu_seqlens (`torch.Tensor`):
+ The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
+ max_seqlen_in_batch (`int`):
+ Maximum sequence length in batch.
+ """
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+ # NOTE: Similar to the `.item()` in prepare_fa2_from_position_ids, with torch compile,
+ # this might cause a graph break
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+ return (
+ indices,
+ cu_seqlens,
+ max_seqlen_in_batch,
+ )
+
+
+def _upad_input(
+ query_layer: torch.Tensor,
+ key_layer: torch.Tensor,
+ value_layer: torch.Tensor,
+ attention_mask: torch.Tensor,
+ query_length: int,
+ unpad_input_func,
+):
+ """
+ Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches.
+ This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary
+ tensors for query, key, value tensors.
+
+ Arguments:
+ query_layer (`torch.Tensor`):
+ Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
+ key_layer (`torch.Tensor`):
+ Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
+ value_layer (`torch.Tensor`):
+ Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
+ attention_mask (`torch.Tensor`):
+ Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
+ query_length (`int`):
+ Target length.
+ unpad_input_func:
+ The function to use for unpadding the input tensors.
+
+ Return:
+ query_layer (`torch.Tensor`):
+ Query state without padding. Shape: (total_target_length, num_heads, head_dim).
+ key_layer (`torch.Tensor`):
+ Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
+ value_layer (`torch.Tensor`):
+ Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
+ indices_q (`torch.Tensor`):
+ The indices of non-masked tokens from the flattened input target sequence.
+ (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`):
+ The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
+ Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
+ """
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
+
+ # With static caches, the k/v states may be larger than the mask -> we need to slice them to avoid generating garbage
+ # It's a bit of an anti-pattern, but otherwise we silently compute wrong attentions scores
+ if key_layer.shape[1] > (seq_len := attention_mask.shape[-1]):
+ key_layer, value_layer = key_layer[:, :seq_len, :, :], value_layer[:, :seq_len, :, :]
+
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
+
+ key_layer = _index_first_axis(key_layer, indices_k)
+ value_layer = _index_first_axis(value_layer, indices_k)
+ if query_length == kv_seq_len:
+ query_layer = _index_first_axis(query_layer, indices_k)
+ cu_seqlens_q = cu_seqlens_k
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
+ indices_q = indices_k
+ elif query_length == 1:
+ max_seqlen_in_batch_q = 1
+ cu_seqlens_q = torch.arange(
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
+ ) # There is a memcpy here, that is very bad.
+ indices_q = cu_seqlens_q[:-1]
+ query_layer = query_layer.squeeze(1)
+ else:
+ # The -q_len: slice assumes left padding.
+ attention_mask = attention_mask[:, -query_length:]
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input_func(query_layer, attention_mask)
+
+ return (
+ query_layer,
+ key_layer,
+ value_layer,
+ indices_q,
+ (cu_seqlens_q, cu_seqlens_k),
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+ )
+
+
+def prepare_fa_kwargs_from_position_ids(position_ids):
+ """
+ This function returns all the necessary kwargs to call `flash_attn_varlen_func` extracted from position_ids.
+
+ Arguments:
+ position_ids (`torch.Tensor`):
+ Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
+
+ Return:
+ (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`):
+ The cumulative sequence lengths for the target (query) and source (key, value), used to index into
+ ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
+ Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query,
+ `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
+ """
+ tensor_kwargs = {"dtype": torch.int32, "device": position_ids.device}
+
+ position_ids = position_ids.view(-1)
+ indices_q = (position_ids == 0).nonzero().view(-1)
+
+ cu_seq_lens_q = torch.cat(
+ (
+ indices_q.to(**tensor_kwargs),
+ torch.tensor(position_ids.size(), **tensor_kwargs),
+ )
+ )
+ cu_seq_lens_k = cu_seq_lens_q
+
+ # https://github.com/Dao-AILab/flash-attention/blob/2dd8078adc1d9b74e315ee99718c0dea0de8eeb6/flash_attn/flash_attn_interface.py#L1423-L1424
+ # We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing
+ # for some models (e.g. qwen2-vl).
+ max_length_q = cu_seq_lens_q.diff().max()
+ # NOTE: With torch compile, this will cause a graph break if you don't set
+ # `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call
+ # `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass.
+ # This is a limitation of flash attention API, as the function `flash_attn_varlen_func`
+ # requires `max_length_q`, `max_length_k` to be passed as `int` and not `torch.Tensor`.
+ max_length_q = max_length_q.item()
+ max_length_k = max_length_q
+
+ return (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k)
+
+
+def _prepare_from_posids(query, key, value, position_ids):
+ """
+ This function returns necessary arguments to call `flash_attn_varlen_func`.
+ All three query, key, value states will be flattened.
+ Cumulative lengths of each examples in the batch will be extracted from position_ids.
+ NOTE: ideally cumulative lengths should be prepared at the data collator stage
+
+ Arguments:
+ query (`torch.Tensor`):
+ Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
+ key (`torch.Tensor`):
+ Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
+ value (`torch.Tensor`):
+ Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
+ position_ids (`torch.Tensor`):
+ Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
+
+ Return:
+ query (`torch.Tensor`):
+ Query state without padding. Shape: (total_target_length, num_heads, head_dim).
+ key (`torch.Tensor`):
+ Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
+ value (`torch.Tensor`):
+ Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
+ (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`):
+ The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
+ Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
+ """
+ query = query.contiguous().view(-1, query.size(-2), query.size(-1))
+ key = key.contiguous().view(-1, key.size(-2), key.size(-1))
+ value = value.contiguous().view(-1, value.size(-2), value.size(-1))
+
+ (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = prepare_fa_kwargs_from_position_ids(position_ids)
+
+ return (query, key, value, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k))
+
+
+def _is_packed_sequence(position_ids, batch_size):
+ """
+ Check the position ids whether packed sequences are indicated or not
+ 1. Position ids exist
+ 2. Flattened sequences only are supported
+ 3. Compile-friendly `not (torch.diff(position_ids, dim=-1) >= 0).all()`, i.e. we have multiple increasing sequences
+ """
+ if position_ids is None:
+ return False
+
+ increasing_position_sequences = (
+ torch.arange(position_ids.shape[1], device=position_ids.device) + position_ids.min()
+ )
+ return batch_size == 1 and (increasing_position_sequences - position_ids).abs().sum().bool()
+
+
+def fa_peft_integration_check(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ target_dtype: Optional[torch.dtype] = None,
+):
+ """
+ PEFT usually casts the layer norms in float32 for training stability reasons
+ therefore the input hidden states gets silently casted in float32. Hence, we need
+ cast them back in float16 / bfloat16 just to be sure everything works as expected.
+ This might slowdown training & inference so it is recommended to not cast the LayerNorms!
+ """
+ if target_dtype and q.dtype == torch.float32:
+ logger.warning_once(f"Casting fp32 inputs back to {target_dtype} for flash-attn compatibility.")
+ q, k, v = q.to(target_dtype), k.to(target_dtype), v.to(target_dtype)
+ return q, k, v
+
+
+class FlashAttentionKwargs(TypedDict, total=False):
+ """
+ Keyword arguments for Flash Attention with Compile.
+
+ Attributes:
+ cu_seq_lens_q (`torch.LongTensor`, *optional*)
+ Gets cumulative sequence length for query state.
+ cu_seq_lens_k (`torch.LongTensor`, *optional*)
+ Gets cumulative sequence length for key state.
+ max_length_q (`int`, *optional*):
+ Maximum sequence length for query state.
+ max_length_k (`int`, *optional*):
+ Maximum sequence length for key state.
+ """
+
+ cu_seq_lens_q: Optional[torch.LongTensor]
+ cu_seq_lens_k: Optional[torch.LongTensor]
+ max_length_q: Optional[int]
+ max_length_k: Optional[int]
+
+
+def _process_flash_attention_kwargs(
+ query_length: int,
+ key_length: int,
+ is_causal: bool,
+ dropout: float = 0.0,
+ softmax_scale: Optional[float] = None,
+ sliding_window: Optional[int] = None,
+ use_top_left_mask: bool = False,
+ softcap: Optional[float] = None,
+ deterministic: Optional[bool] = None,
+ s_aux: Optional[torch.Tensor] = None,
+ supports_mapping: Optional[dict[str, bool]] = None,
+ **kwargs,
+):
+ """
+ Returns a set of kwargs that are passed down to the according flash attention function based on
+ requested features and whether it is supported - depends on the version and kernel implementation
+ which is dynamically configured at `lazy_import_flash_attention`. The (un)supported features can be
+ inspected in `supports_mapping`, see `_lazy_define_process_function` for more details.
+
+ Args:
+ query_length (`int`):
+ Length of the query states
+ key_length (`int`):
+ Length of the key states
+ is_causal (`bool`):
+ Whether we perform causal (decoder) attention or full attention.
+ dropout (`float`):
+ Attention dropout.
+ softmax_scale (`float`, *optional*):
+ The scaling of QK^T before applying softmax. Default to `1 / sqrt(head_dim)`.
+ sliding_window (`int`, *optional*):
+ The size of the sliding window, i.e. we look at a max of `sliding_window` tokens back.
+ use_top_left_mask (`bool`):
+ Deprecated behavior of older versions of flash attention requiring different masking.
+ softcap (`float`, *optional*):
+ Softcap for the attention logits, used e.g. in gemma2.
+ deterministic (`bool`, *optional*):
+ Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled.
+ s_aux (`torch.Tensor`, *optional*):
+ Attention sink auxiliary that adds a `bias` to the attention calculation via an additional head.
+ Return:
+ flash_kwargs (`dict`):
+ A dict of kwargs that are requested and supported.
+ """
+ flash_kwargs = {
+ "causal": is_causal and not (use_top_left_mask and query_length == 1),
+ "softmax_scale": softmax_scale,
+ }
+
+ if supports_mapping["dropout_p"]:
+ flash_kwargs["dropout_p"] = dropout
+
+ if supports_mapping["window_size"] and sliding_window is not None and key_length > sliding_window:
+ # The flash attention API sets inclusive boundaries, i.e. (4, 0) would take 4 tokens to the left
+ # and the current token for a total size of 5. However, we usually define our window sizes by
+ # their total window size (when causal). Encoder models as of now seldom use SWA and when they
+ # do, they have a custom workaround (e.g. ModernBERT) which would align with this symmetric logic, i.e.
+ # for a total of `2*sliding_window + 1`.
+ flash_kwargs["window_size"] = (sliding_window - 1, sliding_window - 1)
+
+ if supports_mapping["deterministic"]:
+ flash_kwargs["deterministic"] = (
+ deterministic if deterministic is not None else os.getenv("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
+ )
+
+ if supports_mapping["softcap"] and softcap is not None:
+ flash_kwargs["softcap"] = softcap
+
+ # Only within kernel implementation atm
+ if supports_mapping["s_aux"] and s_aux is not None:
+ flash_kwargs["s_aux"] = s_aux
+
+ return flash_kwargs
+
+
+def _flash_attention_forward(
+ query_states: torch.Tensor,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ query_length: int,
+ is_causal: bool,
+ dropout: float = 0.0,
+ position_ids: Optional[torch.Tensor] = None,
+ softmax_scale: Optional[float] = None,
+ sliding_window: Optional[int] = None,
+ use_top_left_mask: bool = False,
+ softcap: Optional[float] = None,
+ deterministic: Optional[bool] = None,
+ cu_seq_lens_q: Optional[torch.LongTensor] = None,
+ cu_seq_lens_k: Optional[torch.LongTensor] = None,
+ max_length_q: Optional[int] = None,
+ max_length_k: Optional[int] = None,
+ target_dtype: Optional[torch.dtype] = None,
+ implementation: Optional[str] = None,
+ **kwargs,
+):
+ """
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
+ first unpad the input, then computes the attention scores and pad the final attention scores.
+
+ (Optional) kwargs are described further in `_process_flash_attention_kwargs` and `FlashAttentionKwargs`.
+
+ Args:
+ query_states (`torch.Tensor`):
+ Input query states to be passed to Flash Attention API
+ key_states (`torch.Tensor`):
+ Input key states to be passed to Flash Attention API
+ value_states (`torch.Tensor`):
+ Input value states to be passed to Flash Attention API
+ attention_mask (`torch.Tensor`, *optional*):
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+ position of padding tokens and 1 for the position of non-padding tokens.
+ implementation (`str`, *optional*):
+ The attention implementation to use. If None, will default to the one based on the environment.
+ """
+ (flash_fn, flash_varlen_fn, pad_fn, unpad_fn), process_flash_kwargs_fn = lazy_import_flash_attention(
+ implementation
+ )
+
+ # PEFT possibly silently casts tensors to fp32, this potentially reconverts to correct dtype or is a no op
+ query_states, key_states, value_states = fa_peft_integration_check(
+ query_states, key_states, value_states, target_dtype
+ )
+
+ # Extract the flash attention kwargs that have been requested (and are supported by the implementation)
+ flash_kwargs = process_flash_kwargs_fn(
+ query_length=query_length,
+ key_length=key_states.size(1),
+ is_causal=is_causal,
+ dropout=dropout,
+ softmax_scale=softmax_scale,
+ sliding_window=sliding_window,
+ use_top_left_mask=use_top_left_mask,
+ softcap=softcap,
+ deterministic=deterministic,
+ **kwargs,
+ )
+
+ # We will use `flash_varlen_fn` to prevent cross-example attention and also allow padding free approach under two cases:
+ # Case 1. If position ids is provided and the position ids indicate packed sequences, see `_is_packed_sequence`.
+ # Case 2. Some models pass directly pre-computed `cu_seqlens` so we don't need to infer it from position ids. It is safe to
+ # use `flash_varlen_fn` knowing we already have all necessary the kwargs.
+ #
+ # NOTE: it is user's responsibility to take care of flattening `position_ids` if that's needed by the model.
+ # See #39121 for more information.
+ is_fa_with_position_ids = _is_packed_sequence(position_ids, batch_size=query_states.size(0))
+ is_fa_with_varlen_kwargs = all(
+ kwarg is not None for kwarg in (cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k)
+ )
+
+ # Contains at least one padding token in the sequence
+ if attention_mask is not None:
+ q, k, v, indices_q, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = _upad_input(
+ query_states, key_states, value_states, attention_mask, query_length, unpad_fn
+ )
+
+ # TODO for now this is required to work with
+ # https://huggingface.co/kernels-community/metal-flash-sdpa/blob/main/torch-ext/metal_flash_sdpa/__init__.py
+ if "mps" in str(q.device):
+ cu_seq_lens_k = cu_seq_lens_k.clone()
+
+ out_unpad = flash_varlen_fn(
+ q,
+ k,
+ v,
+ cu_seqlens_q=cu_seq_lens_q,
+ cu_seqlens_k=cu_seq_lens_k,
+ max_seqlen_q=max_length_q,
+ max_seqlen_k=max_length_k,
+ **flash_kwargs,
+ )
+ if isinstance(out_unpad, tuple):
+ out_unpad = out_unpad[0]
+
+ out = pad_fn(out_unpad, indices_q, query_states.size(0), query_length)
+
+ # Padding free, i.e. sequences flattened into one total sequence
+ elif is_fa_with_varlen_kwargs or is_fa_with_position_ids:
+ if cu_seq_lens_q is None or cu_seq_lens_k is None:
+ q, k, v, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = _prepare_from_posids(
+ query_states, key_states, value_states, position_ids
+ )
+ else:
+ q = query_states.reshape(-1, query_states.size(-2), query_states.size(-1))
+ k = key_states.reshape(-1, key_states.size(-2), key_states.size(-1))
+ v = value_states.reshape(-1, value_states.size(-2), value_states.size(-1))
+
+ # TODO for now this is required to work with
+ # https://huggingface.co/kernels-community/metal-flash-sdpa/blob/main/torch-ext/metal_flash_sdpa/__init__.py
+ if "mps" in str(q.device):
+ cu_seq_lens_k = cu_seq_lens_k.clone()
+
+ out = flash_varlen_fn(
+ q,
+ k,
+ v,
+ cu_seqlens_q=cu_seq_lens_q,
+ cu_seqlens_k=cu_seq_lens_k,
+ max_seqlen_q=max_length_q,
+ max_seqlen_k=max_length_k,
+ **flash_kwargs,
+ )
+ if isinstance(out, tuple):
+ out = out[0]
+
+ out = out.view(query_states.size(0), -1, out.size(-2), out.size(-1))
+
+ # No padding
+ else:
+ out = flash_fn(query_states, key_states, value_states, **flash_kwargs)
+ if isinstance(out, tuple):
+ out = out[0]
+
+ return out
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_flax_outputs.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_flax_outputs.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a25a6059a255659c6d900b35d2ffa7cab57f071
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_flax_outputs.py
@@ -0,0 +1,700 @@
+# Copyright 2021 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional
+
+import flax
+import jax.numpy as jnp
+
+from .utils import ModelOutput
+
+
+@flax.struct.dataclass
+class FlaxBaseModelOutput(ModelOutput):
+ """
+ Base class for model's outputs, with potential hidden states and attentions.
+
+ Args:
+ last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ last_hidden_state: Optional[jnp.ndarray] = None
+ hidden_states: Optional[tuple[jnp.ndarray]] = None
+ attentions: Optional[tuple[jnp.ndarray]] = None
+
+
+@flax.struct.dataclass
+class FlaxBaseModelOutputWithNoAttention(ModelOutput):
+ """
+ Base class for model's outputs, with potential hidden states.
+
+ Args:
+ last_hidden_state (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one
+ for the output of each layer) of shape `(batch_size, num_channels, height, width)`. Hidden-states of the
+ model at the output of each layer plus the optional initial embedding outputs.
+ """
+
+ last_hidden_state: Optional[jnp.ndarray] = None
+ hidden_states: Optional[tuple[jnp.ndarray]] = None
+
+
+@flax.struct.dataclass
+class FlaxBaseModelOutputWithPoolingAndNoAttention(ModelOutput):
+ """
+ Base class for model's outputs that also contains a pooling of the last hidden states.
+
+ Args:
+ last_hidden_state (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ pooler_output (`jnp.ndarray` of shape `(batch_size, hidden_size)`):
+ Last layer hidden-state after a pooling operation on the spatial dimensions.
+ hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one
+ for the output of each layer) of shape `(batch_size, num_channels, height, width)`. Hidden-states of the
+ model at the output of each layer plus the optional initial embedding outputs.
+ """
+
+ last_hidden_state: Optional[jnp.ndarray] = None
+ pooler_output: Optional[jnp.ndarray] = None
+ hidden_states: Optional[tuple[jnp.ndarray]] = None
+
+
+@flax.struct.dataclass
+class FlaxImageClassifierOutputWithNoAttention(ModelOutput):
+ """
+ Base class for outputs of image classification models.
+
+ Args:
+ logits (`jnp.ndarray` of shape `(batch_size, config.num_labels)`):
+ Classification (or regression if config.num_labels==1) scores (before SoftMax).
+ hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when
+ `config.output_hidden_states=True`):
+ Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one
+ for the output of each stage) of shape `(batch_size, num_channels, height, width)`. Hidden-states (also
+ called feature maps) of the model at the output of each stage.
+ """
+
+ logits: Optional[jnp.ndarray] = None
+ hidden_states: Optional[tuple[jnp.ndarray]] = None
+
+
+@flax.struct.dataclass
+class FlaxBaseModelOutputWithPast(ModelOutput):
+ """
+ Base class for model's outputs, with potential hidden states and attentions.
+
+ Args:
+ last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ past_key_values (`dict[str, jnp.ndarray]`):
+ Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
+ auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
+ hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ last_hidden_state: Optional[jnp.ndarray] = None
+ past_key_values: Optional[dict[str, jnp.ndarray]] = None
+ hidden_states: Optional[tuple[jnp.ndarray]] = None
+ attentions: Optional[tuple[jnp.ndarray]] = None
+
+
+@flax.struct.dataclass
+class FlaxBaseModelOutputWithPooling(ModelOutput):
+ """
+ Base class for model's outputs that also contains a pooling of the last hidden states.
+
+ Args:
+ last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ pooler_output (`jnp.ndarray` of shape `(batch_size, hidden_size)`):
+ Last layer hidden-state of the first token of the sequence (classification token) further processed by a
+ Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence
+ prediction (classification) objective during pretraining.
+ hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ last_hidden_state: Optional[jnp.ndarray] = None
+ pooler_output: Optional[jnp.ndarray] = None
+ hidden_states: Optional[tuple[jnp.ndarray]] = None
+ attentions: Optional[tuple[jnp.ndarray]] = None
+
+
+@flax.struct.dataclass
+class FlaxBaseModelOutputWithPoolingAndCrossAttentions(ModelOutput):
+ """
+ Base class for model's outputs that also contains a pooling of the last hidden states.
+
+ Args:
+ last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ pooler_output (`jnp.ndarray` of shape `(batch_size, hidden_size)`):
+ Last layer hidden-state of the first token of the sequence (classification token) after further processing
+ through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns
+ the classification token after processing through a linear layer and a tanh activation function. The linear
+ layer weights are trained from the next sentence prediction (classification) objective during pretraining.
+ hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one
+ for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+ weighted average in the cross-attention heads.
+ past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
+ `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
+ encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
+ `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
+ input) to speed up sequential decoding.
+ """
+
+ last_hidden_state: Optional[jnp.ndarray] = None
+ pooler_output: Optional[jnp.ndarray] = None
+ hidden_states: Optional[tuple[jnp.ndarray]] = None
+ past_key_values: Optional[tuple[tuple[jnp.ndarray]]] = None
+ attentions: Optional[tuple[jnp.ndarray]] = None
+ cross_attentions: Optional[tuple[jnp.ndarray]] = None
+
+
+@flax.struct.dataclass
+class FlaxBaseModelOutputWithPastAndCrossAttentions(ModelOutput):
+ """
+ Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
+
+ Args:
+ last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+
+ If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
+ hidden_size)` is output.
+ past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
+ `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
+ encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
+ `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
+ input) to speed up sequential decoding.
+ hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+ weighted average in the cross-attention heads.
+ """
+
+ last_hidden_state: Optional[jnp.ndarray] = None
+ past_key_values: Optional[tuple[tuple[jnp.ndarray]]] = None
+ hidden_states: Optional[tuple[jnp.ndarray]] = None
+ attentions: Optional[tuple[jnp.ndarray]] = None
+ cross_attentions: Optional[tuple[jnp.ndarray]] = None
+
+
+@flax.struct.dataclass
+class FlaxSeq2SeqModelOutput(ModelOutput):
+ """
+ Base class for model encoder's outputs that also contains : pre-computed hidden states that can speed up sequential
+ decoding.
+
+ Args:
+ last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the decoder of the model.
+
+ If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
+ hidden_size)` is output.
+ past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+ decoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
+ decoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+ weighted average in the cross-attention heads.
+ encoder_last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder of the model.
+ encoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
+ encoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ """
+
+ last_hidden_state: Optional[jnp.ndarray] = None
+ past_key_values: Optional[tuple[tuple[jnp.ndarray]]] = None
+ decoder_hidden_states: Optional[tuple[jnp.ndarray]] = None
+ decoder_attentions: Optional[tuple[jnp.ndarray]] = None
+ cross_attentions: Optional[tuple[jnp.ndarray]] = None
+ encoder_last_hidden_state: Optional[jnp.ndarray] = None
+ encoder_hidden_states: Optional[tuple[jnp.ndarray]] = None
+ encoder_attentions: Optional[tuple[jnp.ndarray]] = None
+
+
+@flax.struct.dataclass
+class FlaxCausalLMOutputWithCrossAttentions(ModelOutput):
+ """
+ Base class for causal language model (or autoregressive) outputs.
+
+ Args:
+ logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Cross attentions weights after the attention softmax, used to compute the weighted average in the
+ cross-attention heads.
+ past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `jnp.ndarray` tuples of length `config.n_layers`, with each tuple containing the cached key, value
+ states of the self-attention and the cross-attention layers if model is used in encoder-decoder setting.
+ Only relevant if `config.is_decoder = True`.
+
+ Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ """
+
+ logits: Optional[jnp.ndarray] = None
+ past_key_values: Optional[tuple[tuple[jnp.ndarray]]] = None
+ hidden_states: Optional[tuple[jnp.ndarray]] = None
+ attentions: Optional[tuple[jnp.ndarray]] = None
+ cross_attentions: Optional[tuple[jnp.ndarray]] = None
+
+
+@flax.struct.dataclass
+class FlaxMaskedLMOutput(ModelOutput):
+ """
+ Base class for masked language models outputs.
+
+ Args:
+ logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ logits: Optional[jnp.ndarray] = None
+ hidden_states: Optional[tuple[jnp.ndarray]] = None
+ attentions: Optional[tuple[jnp.ndarray]] = None
+
+
+FlaxCausalLMOutput = FlaxMaskedLMOutput
+
+
+@flax.struct.dataclass
+class FlaxSeq2SeqLMOutput(ModelOutput):
+ """
+ Base class for sequence-to-sequence language models outputs.
+
+ Args:
+ logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+ decoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
+ decoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+ weighted average in the cross-attention heads.
+ encoder_last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder of the model.
+ encoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
+ encoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ """
+
+ logits: Optional[jnp.ndarray] = None
+ past_key_values: Optional[tuple[tuple[jnp.ndarray]]] = None
+ decoder_hidden_states: Optional[tuple[jnp.ndarray]] = None
+ decoder_attentions: Optional[tuple[jnp.ndarray]] = None
+ cross_attentions: Optional[tuple[jnp.ndarray]] = None
+ encoder_last_hidden_state: Optional[jnp.ndarray] = None
+ encoder_hidden_states: Optional[tuple[jnp.ndarray]] = None
+ encoder_attentions: Optional[tuple[jnp.ndarray]] = None
+
+
+@flax.struct.dataclass
+class FlaxNextSentencePredictorOutput(ModelOutput):
+ """
+ Base class for outputs of models predicting if two sentences are consecutive or not.
+
+ Args:
+ logits (`jnp.ndarray` of shape `(batch_size, 2)`):
+ Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
+ before SoftMax).
+ hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ logits: Optional[jnp.ndarray] = None
+ hidden_states: Optional[tuple[jnp.ndarray]] = None
+ attentions: Optional[tuple[jnp.ndarray]] = None
+
+
+@flax.struct.dataclass
+class FlaxSequenceClassifierOutput(ModelOutput):
+ """
+ Base class for outputs of sentence classification models.
+
+ Args:
+ logits (`jnp.ndarray` of shape `(batch_size, config.num_labels)`):
+ Classification (or regression if config.num_labels==1) scores (before SoftMax).
+ hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ logits: Optional[jnp.ndarray] = None
+ hidden_states: Optional[tuple[jnp.ndarray]] = None
+ attentions: Optional[tuple[jnp.ndarray]] = None
+
+
+@flax.struct.dataclass
+class FlaxSeq2SeqSequenceClassifierOutput(ModelOutput):
+ """
+ Base class for outputs of sequence-to-sequence sentence classification models.
+
+ Args:
+ logits (`jnp.ndarray` of shape `(batch_size, config.num_labels)`):
+ Classification (or regression if config.num_labels==1) scores (before SoftMax).
+ past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+ decoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
+ decoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+ weighted average in the cross-attention heads.
+ encoder_last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder of the model.
+ encoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
+ encoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ """
+
+ logits: Optional[jnp.ndarray] = None
+ past_key_values: Optional[tuple[tuple[jnp.ndarray]]] = None
+ decoder_hidden_states: Optional[tuple[jnp.ndarray]] = None
+ decoder_attentions: Optional[tuple[jnp.ndarray]] = None
+ cross_attentions: Optional[tuple[jnp.ndarray]] = None
+ encoder_last_hidden_state: Optional[jnp.ndarray] = None
+ encoder_hidden_states: Optional[tuple[jnp.ndarray]] = None
+ encoder_attentions: Optional[tuple[jnp.ndarray]] = None
+
+
+@flax.struct.dataclass
+class FlaxMultipleChoiceModelOutput(ModelOutput):
+ """
+ Base class for outputs of multiple choice models.
+
+ Args:
+ logits (`jnp.ndarray` of shape `(batch_size, num_choices)`):
+ *num_choices* is the second dimension of the input tensors. (see *input_ids* above).
+
+ Classification scores (before SoftMax).
+ hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ logits: Optional[jnp.ndarray] = None
+ hidden_states: Optional[tuple[jnp.ndarray]] = None
+ attentions: Optional[tuple[jnp.ndarray]] = None
+
+
+@flax.struct.dataclass
+class FlaxTokenClassifierOutput(ModelOutput):
+ """
+ Base class for outputs of token classification models.
+
+ Args:
+ logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.num_labels)`):
+ Classification scores (before SoftMax).
+ hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ logits: Optional[jnp.ndarray] = None
+ hidden_states: Optional[tuple[jnp.ndarray]] = None
+ attentions: Optional[tuple[jnp.ndarray]] = None
+
+
+@flax.struct.dataclass
+class FlaxQuestionAnsweringModelOutput(ModelOutput):
+ """
+ Base class for outputs of question answering models.
+
+ Args:
+ start_logits (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
+ Span-start scores (before SoftMax).
+ end_logits (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
+ Span-end scores (before SoftMax).
+ hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ start_logits: Optional[jnp.ndarray] = None
+ end_logits: Optional[jnp.ndarray] = None
+ hidden_states: Optional[tuple[jnp.ndarray]] = None
+ attentions: Optional[tuple[jnp.ndarray]] = None
+
+
+@flax.struct.dataclass
+class FlaxSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
+ """
+ Base class for outputs of sequence-to-sequence question answering models.
+
+ Args:
+ start_logits (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
+ Span-start scores (before SoftMax).
+ end_logits (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
+ Span-end scores (before SoftMax).
+ past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+ decoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
+ decoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+ weighted average in the cross-attention heads.
+ encoder_last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder of the model.
+ encoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
+ encoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ """
+
+ start_logits: Optional[jnp.ndarray] = None
+ end_logits: Optional[jnp.ndarray] = None
+ past_key_values: Optional[tuple[tuple[jnp.ndarray]]] = None
+ decoder_hidden_states: Optional[tuple[jnp.ndarray]] = None
+ decoder_attentions: Optional[tuple[jnp.ndarray]] = None
+ cross_attentions: Optional[tuple[jnp.ndarray]] = None
+ encoder_last_hidden_state: Optional[jnp.ndarray] = None
+ encoder_hidden_states: Optional[tuple[jnp.ndarray]] = None
+ encoder_attentions: Optional[tuple[jnp.ndarray]] = None
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_flax_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_flax_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc9a4d473f36f95bb13d6de17dc0bfa7cfae2279
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_flax_utils.py
@@ -0,0 +1,1274 @@
+# coding=utf-8
+# Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import gc
+import json
+import os
+import warnings
+from functools import partial
+from pickle import UnpicklingError
+from typing import Any, Optional, Union
+
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+import msgpack.exceptions
+from flax.core.frozen_dict import FrozenDict, unfreeze
+from flax.serialization import from_bytes, to_bytes
+from flax.traverse_util import flatten_dict, unflatten_dict
+from jax.random import PRNGKey
+
+from .configuration_utils import PretrainedConfig
+from .dynamic_module_utils import custom_object_save
+from .generation import FlaxGenerationMixin, GenerationConfig
+from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict
+from .utils import (
+ FLAX_WEIGHTS_INDEX_NAME,
+ FLAX_WEIGHTS_NAME,
+ SAFE_WEIGHTS_INDEX_NAME,
+ SAFE_WEIGHTS_NAME,
+ WEIGHTS_INDEX_NAME,
+ WEIGHTS_NAME,
+ PushToHubMixin,
+ add_code_sample_docstrings,
+ add_start_docstrings_to_model_forward,
+ cached_file,
+ copy_func,
+ download_url,
+ has_file,
+ is_offline_mode,
+ is_remote_url,
+ logging,
+ replace_return_docstrings,
+)
+from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
+from .utils.import_utils import is_safetensors_available
+
+
+if is_safetensors_available():
+ from safetensors import safe_open
+ from safetensors.flax import load_file as safe_load_file
+ from safetensors.flax import save_file as safe_save_file
+
+logger = logging.get_logger(__name__)
+
+
+def quick_gelu(x):
+ return x * jax.nn.sigmoid(1.702 * x)
+
+
+ACT2FN = {
+ "gelu": partial(nn.gelu, approximate=False),
+ "relu": nn.relu,
+ "silu": nn.swish,
+ "swish": nn.swish,
+ "gelu_new": partial(nn.gelu, approximate=True),
+ "quick_gelu": quick_gelu,
+ "gelu_pytorch_tanh": partial(nn.gelu, approximate=True),
+ "tanh": nn.tanh,
+}
+
+
+def flax_shard_checkpoint(params, max_shard_size="10GB"):
+ """
+ Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
+ given size. The sub-checkpoints are determined by iterating through the `state_dict` in the order of its keys, so
+ there is no optimization made to make each sub-checkpoint as close as possible to the maximum size passed. For
+ example, if the limit is 10GB and we have weights of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as
+ [6GB], [6+2GB], [6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB].
+
+
+
+ If one of the model's weight is bigger that `max_shard_size`, it will end up in its own sub-checkpoint which will
+ have a size greater than `max_shard_size`.
+
+
+
+ Args:
+ params (`Union[Dict, FrozenDict]`): A `PyTree` of model parameters.
+ max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
+ The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit
+ (like `"5MB"`).
+ """
+ max_shard_size = convert_file_size_to_int(max_shard_size)
+
+ sharded_state_dicts = []
+ current_block = {}
+ current_block_size = 0
+ total_size = 0
+
+ # flatten the weights to chunk
+ weights = flatten_dict(params, sep="/")
+ for item in weights:
+ weight_size = weights[item].size * weights[item].dtype.itemsize
+
+ # If this weight is going to tip up over the maximal size, we split.
+ if current_block_size + weight_size > max_shard_size:
+ sharded_state_dicts.append(current_block)
+ current_block = {}
+ current_block_size = 0
+
+ current_block[item] = weights[item]
+ current_block_size += weight_size
+ total_size += weight_size
+
+ # Add the last block
+ sharded_state_dicts.append(current_block)
+
+ # If we only have one shard, we return it
+ if len(sharded_state_dicts) == 1:
+ return {FLAX_WEIGHTS_NAME: sharded_state_dicts[0]}, None
+
+ # Otherwise, let's build the index
+ weight_map = {}
+ shards = {}
+ for idx, shard in enumerate(sharded_state_dicts):
+ shard_file = FLAX_WEIGHTS_NAME.replace(".msgpack", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.msgpack")
+ shards[shard_file] = shard
+ for weight_name in shard:
+ weight_map[weight_name] = shard_file
+
+ # Add the metadata
+ metadata = {"total_size": total_size}
+ index = {"metadata": metadata, "weight_map": weight_map}
+ return shards, index
+
+
+class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
+ r"""
+ Base class for all models.
+
+ [`FlaxPreTrainedModel`] takes care of storing the configuration of the models and handles methods for loading,
+ downloading and saving models.
+
+ Class attributes (overridden by derived classes):
+
+ - **config_class** ([`PretrainedConfig`]) -- A subclass of [`PretrainedConfig`] to use as configuration class
+ for this model architecture.
+ - **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived
+ classes of the same architecture adding modules on top of the base model.
+ - **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP
+ models, `pixel_values` for vision models and `input_values` for speech models).
+ """
+
+ config_class = None
+ base_model_prefix = ""
+ main_input_name = "input_ids"
+ _auto_class = None
+ _missing_keys = set()
+
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ module: nn.Module,
+ input_shape: tuple = (1, 1),
+ seed: int = 0,
+ dtype: jnp.dtype = jnp.float32,
+ _do_init: bool = True,
+ ):
+ logger.warning_once(
+ "TensorFlow and JAX classes are deprecated and will be removed in Transformers v5. We "
+ "recommend migrating to PyTorch classes or pinning your version of Transformers."
+ )
+ if config is None:
+ raise ValueError("config cannot be None")
+
+ if module is None:
+ raise ValueError("module cannot be None")
+
+ # Those are private to be exposed as typed property on derived classes.
+ self._config = config
+ self._module = module
+
+ # Those are public as their type is generic to every derived classes.
+ self.key = PRNGKey(seed)
+ self.dtype = dtype
+ self.input_shape = input_shape
+ self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
+
+ # To check if the model was initialized automatically.
+ self._is_initialized = _do_init
+
+ if _do_init:
+ # randomly initialized parameters
+ random_params = self.init_weights(self.key, input_shape)
+ params_shape_tree = jax.eval_shape(lambda params: params, random_params)
+ else:
+ init_fn = partial(self.init_weights, input_shape=input_shape)
+ params_shape_tree = jax.eval_shape(init_fn, self.key)
+
+ logger.info(
+ "Model weights are not initialized as `_do_init` is set to `False`. "
+ f"Make sure to call `{self.__class__.__name__}.init_weights` manually to initialize the weights."
+ )
+
+ # get the shape of the parameters
+ self._params_shape_tree = params_shape_tree
+
+ # save required_params as set
+ self._required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys())
+
+ # initialize the parameters
+ if _do_init:
+ self.params = random_params
+
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> dict:
+ raise NotImplementedError(f"init method has to be implemented for {self}")
+
+ def enable_gradient_checkpointing(self):
+ raise NotImplementedError(f"gradient checkpointing method has to be implemented for {self}")
+
+ @classmethod
+ def _from_config(cls, config, **kwargs):
+ """
+ All context managers that the model should be initialized under go here.
+ """
+ return cls(config, **kwargs)
+
+ @property
+ def framework(self) -> str:
+ """
+ :str: Identifies that this is a Flax model.
+ """
+ return "flax"
+
+ @property
+ def config(self) -> PretrainedConfig:
+ return self._config
+
+ @property
+ def module(self) -> nn.Module:
+ return self._module
+
+ @property
+ def params(self) -> Union[dict, FrozenDict]:
+ if not self._is_initialized:
+ raise ValueError(
+ "`params` cannot be accessed from model when the model is created with `_do_init=False`. "
+ "You must call `init_weights` manually and store the params outside of the model and "
+ "pass it explicitly where needed."
+ )
+ return self._params
+
+ @property
+ def required_params(self) -> set:
+ return self._required_params
+
+ @property
+ def params_shape_tree(self) -> dict:
+ return self._params_shape_tree
+
+ @params.setter
+ def params(self, params: Union[dict, FrozenDict]):
+ # don't set params if the model is not initialized
+ if not self._is_initialized:
+ raise ValueError(
+ "`params` cannot be set from model when the model is created with `_do_init=False`. "
+ "You store the params outside of the model."
+ )
+
+ if isinstance(params, FrozenDict):
+ params = unfreeze(params)
+ param_keys = set(flatten_dict(params).keys())
+ if len(self.required_params - param_keys) > 0:
+ raise ValueError(
+ "Some parameters are missing. Make sure that `params` include the following "
+ f"parameters {self.required_params - param_keys}"
+ )
+ self._params = params
+
+ def _cast_floating_to(self, params: Union[dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
+ """
+ Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`.
+ """
+
+ # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
+ def conditional_cast(param):
+ if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
+ param = param.astype(dtype)
+ return param
+
+ if mask is None:
+ return jax.tree_util.tree_map(conditional_cast, params)
+
+ flat_params = flatten_dict(params)
+ flat_mask, _ = jax.tree_util.tree_flatten(mask)
+
+ for masked, key in zip(flat_mask, sorted(flat_params.keys())):
+ if masked:
+ flat_params[key] = conditional_cast(flat_params[key])
+
+ return unflatten_dict(flat_params)
+
+ def to_bf16(self, params: Union[dict, FrozenDict], mask: Any = None):
+ r"""
+ Cast the floating-point `params` to `jax.numpy.bfloat16`. This returns a new `params` tree and does not cast
+ the `params` in place.
+
+ This method can be used on TPU to explicitly convert the model parameters to bfloat16 precision to do full
+ half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed.
+
+ Arguments:
+ params (`Union[Dict, FrozenDict]`):
+ A `PyTree` of model parameters.
+ mask (`Union[Dict, FrozenDict]`):
+ A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params
+ you want to cast, and should be `False` for those you want to skip.
+
+ Examples:
+
+ ```python
+ >>> from transformers import FlaxBertModel
+
+ >>> # load model
+ >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
+ >>> # By default, the model parameters will be in fp32 precision, to cast these to bfloat16 precision
+ >>> model.params = model.to_bf16(model.params)
+ >>> # If you want don't want to cast certain parameters (for example layer norm bias and scale)
+ >>> # then pass the mask as follows
+ >>> from flax import traverse_util
+
+ >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
+ >>> flat_params = traverse_util.flatten_dict(model.params)
+ >>> mask = {
+ ... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
+ ... for path in flat_params
+ ... }
+ >>> mask = traverse_util.unflatten_dict(mask)
+ >>> model.params = model.to_bf16(model.params, mask)
+ ```"""
+ return self._cast_floating_to(params, jnp.bfloat16, mask)
+
+ def to_fp32(self, params: Union[dict, FrozenDict], mask: Any = None):
+ r"""
+ Cast the floating-point `params` to `jax.numpy.float32`. This method can be used to explicitly convert the
+ model parameters to fp32 precision. This returns a new `params` tree and does not cast the `params` in place.
+
+ Arguments:
+ params (`Union[Dict, FrozenDict]`):
+ A `PyTree` of model parameters.
+ mask (`Union[Dict, FrozenDict]`):
+ A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params
+ you want to cast, and should be `False` for those you want to skip
+
+ Examples:
+
+ ```python
+ >>> from transformers import FlaxBertModel
+
+ >>> # Download model and configuration from huggingface.co
+ >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
+ >>> # By default, the model params will be in fp32, to illustrate the use of this method,
+ >>> # we'll first cast to fp16 and back to fp32
+ >>> model.params = model.to_f16(model.params)
+ >>> # now cast back to fp32
+ >>> model.params = model.to_fp32(model.params)
+ ```"""
+ return self._cast_floating_to(params, jnp.float32, mask)
+
+ def to_fp16(self, params: Union[dict, FrozenDict], mask: Any = None):
+ r"""
+ Cast the floating-point `params` to `jax.numpy.float16`. This returns a new `params` tree and does not cast the
+ `params` in place.
+
+ This method can be used on GPU to explicitly convert the model parameters to float16 precision to do full
+ half-precision training or to save weights in float16 for inference in order to save memory and improve speed.
+
+ Arguments:
+ params (`Union[Dict, FrozenDict]`):
+ A `PyTree` of model parameters.
+ mask (`Union[Dict, FrozenDict]`):
+ A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params
+ you want to cast, and should be `False` for those you want to skip
+
+ Examples:
+
+ ```python
+ >>> from transformers import FlaxBertModel
+
+ >>> # load model
+ >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
+ >>> # By default, the model params will be in fp32, to cast these to float16
+ >>> model.params = model.to_fp16(model.params)
+ >>> # If you want don't want to cast certain parameters (for example layer norm bias and scale)
+ >>> # then pass the mask as follows
+ >>> from flax import traverse_util
+
+ >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
+ >>> flat_params = traverse_util.flatten_dict(model.params)
+ >>> mask = {
+ ... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
+ ... for path in flat_params
+ ... }
+ >>> mask = traverse_util.unflatten_dict(mask)
+ >>> model.params = model.to_fp16(model.params, mask)
+ ```"""
+ return self._cast_floating_to(params, jnp.float16, mask)
+
+ @classmethod
+ def load_flax_weights(cls, resolved_archive_file):
+ try:
+ if resolved_archive_file.endswith(".safetensors"):
+ state = safe_load_file(resolved_archive_file)
+ state = unflatten_dict(state, sep=".")
+ else:
+ with open(resolved_archive_file, "rb") as state_f:
+ state = from_bytes(cls, state_f.read())
+ except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
+ try:
+ with open(resolved_archive_file) as f:
+ if f.read().startswith("version"):
+ raise OSError(
+ "You seem to have cloned a repository without having git-lfs installed. Please"
+ " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
+ " folder you cloned."
+ )
+ else:
+ raise ValueError from e
+ except (UnicodeDecodeError, ValueError):
+ raise OSError(f"Unable to convert {resolved_archive_file} to Flax deserializable object. ")
+
+ return state
+
+ @classmethod
+ def load_flax_sharded_weights(cls, shard_files):
+ """
+ This is the same as [`flax.serialization.from_bytes`]
+ (https:lax.readthedocs.io/en/latest/_modules/flax/serialization.html#from_bytes) but for a sharded checkpoint.
+
+ This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being
+ loaded in the model.
+
+ Args:
+ shard_files (`list[str]`:
+ The list of shard files to load.
+
+ Returns:
+ `Dict`: A nested dictionary of the model parameters, in the expected format for flax models : `{'model':
+ {'params': {'...'}}}`.
+ """
+
+ # Load the index
+ state_sharded_dict = {}
+
+ for shard_file in shard_files:
+ # load using msgpack utils
+ try:
+ with open(shard_file, "rb") as state_f:
+ state = from_bytes(cls, state_f.read())
+ except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
+ with open(shard_file) as f:
+ if f.read().startswith("version"):
+ raise OSError(
+ "You seem to have cloned a repository without having git-lfs installed. Please"
+ " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
+ " folder you cloned."
+ )
+ else:
+ raise ValueError from e
+ except (UnicodeDecodeError, ValueError):
+ raise OSError(f"Unable to convert {shard_file} to Flax deserializable object. ")
+
+ state = flatten_dict(state, sep="/")
+ state_sharded_dict.update(state)
+ del state
+ gc.collect()
+
+ # the state dict is unflattened to the match the format of model.params
+ return unflatten_dict(state_sharded_dict, sep="/")
+
+ @classmethod
+ def can_generate(cls) -> bool:
+ """
+ Returns whether this model can generate sequences with `.generate()`. Returns:
+ `bool`: Whether this model can generate sequences with `.generate()`.
+ """
+ # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation.
+ # Alternatively, the model can also have a custom `generate` function.
+ if "GenerationMixin" in str(cls.prepare_inputs_for_generation) and "GenerationMixin" in str(cls.generate):
+ return False
+ return True
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ pretrained_model_name_or_path: Union[str, os.PathLike],
+ dtype: jnp.dtype = jnp.float32,
+ *model_args,
+ config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
+ ignore_mismatched_sizes: bool = False,
+ force_download: bool = False,
+ local_files_only: bool = False,
+ token: Optional[Union[str, bool]] = None,
+ revision: str = "main",
+ **kwargs,
+ ):
+ r"""
+ Instantiate a pretrained flax model from a pre-trained model configuration.
+
+ The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
+ pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
+ task.
+
+ The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
+ weights are discarded.
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ Can be either:
+
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+ - A path to a *directory* containing model weights saved using
+ [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+ - A path or url to a *pt index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In this case,
+ `from_pt` should be set to `True`.
+ dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
+ The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
+ `jax.numpy.bfloat16` (on TPUs).
+
+ This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
+ specified all the computation will be performed with the given `dtype`.
+
+ **Note that this only specifies the dtype of the computation and does not influence the dtype of model
+ parameters.**
+
+ If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
+ [`~FlaxPreTrainedModel.to_bf16`].
+ model_args (sequence of positional arguments, *optional*):
+ All remaining positional arguments will be passed to the underlying model's `__init__` method.
+ config (`Union[PretrainedConfig, str, os.PathLike]`, *optional*):
+ Can be either:
+
+ - an instance of a class derived from [`PretrainedConfig`],
+ - a string or path valid as input to [`~PretrainedConfig.from_pretrained`].
+
+ Configuration for the model to use instead of an automatically loaded configuration. Configuration can
+ be automatically loaded when:
+
+ - The model is a model provided by the library (loaded with the *model id* string of a pretrained
+ model).
+ - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the
+ save directory.
+ - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
+ configuration JSON file named *config.json* is found in the directory.
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
+ standard cache should not be used.
+ from_pt (`bool`, *optional*, defaults to `False`):
+ Load the model weights from a PyTorch checkpoint save file (see docstring of
+ `pretrained_model_name_or_path` argument).
+ ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
+ Whether or not to raise an error if some of the weights from the checkpoint do not have the same size
+ as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
+ checkpoint with 3 labels).
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download:
+ Deprecated and ignored. All downloads are now resumed by default when possible.
+ Will be removed in v5 of Transformers.
+ proxies (`dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only(`bool`, *optional*, defaults to `False`):
+ Whether or not to only look at local files (i.e., do not try to download the model).
+ token (`str` or `bool`, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
+ the token generated when running `hf auth login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+
+
+
+
+ To test a pull request you made on the Hub, you can pass `revision="refs/pr/"`.
+
+
+
+ subfolder (`str`, *optional*, defaults to `""`):
+ In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
+ specify the folder name here.
+ kwargs (remaining dictionary of keyword arguments, *optional*):
+ Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
+ `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
+ automatically loaded:
+
+ - If a configuration is provided with `config`, `**kwargs` will be directly passed to the
+ underlying model's `__init__` method (we assume all relevant updates to the configuration have
+ already been done)
+ - If a configuration is not provided, `kwargs` will be first passed to the configuration class
+ initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that
+ corresponds to a configuration attribute will be used to override said attribute with the
+ supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
+ will be passed to the underlying model's `__init__` function.
+
+ Examples:
+
+ ```python
+ >>> from transformers import BertConfig, FlaxBertModel
+
+ >>> # Download model and configuration from huggingface.co and cache.
+ >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
+ >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
+ >>> model = FlaxBertModel.from_pretrained("./test/saved_model/")
+ >>> # Loading from a PyTorch checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable).
+ >>> config = BertConfig.from_json_file("./pt_model/config.json")
+ >>> model = FlaxBertModel.from_pretrained("./pt_model/pytorch_model.bin", from_pt=True, config=config)
+ ```"""
+ from_pt = kwargs.pop("from_pt", False)
+ resume_download = kwargs.pop("resume_download", None)
+ proxies = kwargs.pop("proxies", None)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ trust_remote_code = kwargs.pop("trust_remote_code", None)
+ from_pipeline = kwargs.pop("_from_pipeline", None)
+ from_auto_class = kwargs.pop("_from_auto", False)
+ _do_init = kwargs.pop("_do_init", True)
+ subfolder = kwargs.pop("subfolder", "")
+ commit_hash = kwargs.pop("_commit_hash", None)
+
+ # Not relevant for Flax Models
+ _ = kwargs.pop("adapter_kwargs", None)
+
+ if use_auth_token is not None:
+ warnings.warn(
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+ FutureWarning,
+ )
+ if token is not None:
+ raise ValueError(
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+ )
+ token = use_auth_token
+
+ if trust_remote_code is True:
+ logger.warning(
+ "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is"
+ " ignored."
+ )
+
+ user_agent = {"file_type": "model", "framework": "flax", "from_auto_class": from_auto_class}
+ if from_pipeline is not None:
+ user_agent["using_pipeline"] = from_pipeline
+
+ if is_offline_mode() and not local_files_only:
+ logger.info("Offline mode: forcing local_files_only=True")
+ local_files_only = True
+
+ # Load config if we don't provide a configuration
+ if not isinstance(config, PretrainedConfig):
+ config_path = config if config is not None else pretrained_model_name_or_path
+ config, model_kwargs = cls.config_class.from_pretrained(
+ config_path,
+ cache_dir=cache_dir,
+ return_unused_kwargs=True,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ _from_auto=from_auto_class,
+ _from_pipeline=from_pipeline,
+ _commit_hash=commit_hash,
+ **kwargs,
+ )
+ else:
+ model_kwargs = kwargs.copy()
+
+ if commit_hash is None:
+ commit_hash = getattr(config, "_commit_hash", None)
+
+ # Add the dtype to model_kwargs
+ model_kwargs["dtype"] = dtype
+
+ # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
+ # index of the files.
+ is_sharded = False
+
+ # Load model
+ if pretrained_model_name_or_path is not None:
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+ is_local = os.path.isdir(pretrained_model_name_or_path)
+ if is_local:
+ if os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)):
+ # Load from a Flax checkpoint
+ archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
+ elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_INDEX_NAME)):
+ # Load from a sharded Flax checkpoint
+ archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_INDEX_NAME)
+ is_sharded = True
+ elif is_safetensors_available() and os.path.isfile(
+ os.path.join(pretrained_model_name_or_path, subfolder, SAFE_WEIGHTS_NAME)
+ ):
+ # Load from a safetensors checkpoint
+ archive_file = os.path.join(pretrained_model_name_or_path, subfolder, SAFE_WEIGHTS_NAME)
+ elif is_safetensors_available() and os.path.isfile(
+ os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME)
+ ):
+ # Load from a safetensors checkpoint
+ archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME)
+ elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)):
+ # Load from a PyTorch checkpoint
+ archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
+ elif from_pt and os.path.isfile(
+ os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_INDEX_NAME)
+ ):
+ # Load from a sharded pytorch checkpoint
+ archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_INDEX_NAME)
+ is_sharded = True
+ # At this stage we don't have a weight file so we will raise an error.
+ elif is_safetensors_available() and os.path.isfile(
+ os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
+ ):
+ # Load from a sharded safetensors checkpoint
+ archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
+ is_sharded = True
+ raise NotImplementedError("Support for sharded checkpoints using safetensors is coming soon!")
+ elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)):
+ raise OSError(
+ f"Error no file named {FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
+ "but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those "
+ "weights."
+ )
+ else:
+ raise OSError(
+ f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
+ f"{pretrained_model_name_or_path}."
+ )
+ elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
+ archive_file = pretrained_model_name_or_path
+ is_local = True
+ elif is_remote_url(pretrained_model_name_or_path):
+ filename = pretrained_model_name_or_path
+ resolved_archive_file = download_url(pretrained_model_name_or_path)
+ else:
+ if from_pt:
+ filename = WEIGHTS_NAME
+ else:
+ filename = FLAX_WEIGHTS_NAME
+
+ try:
+ # Load from URL or cache if already cached
+ cached_file_kwargs = {
+ "cache_dir": cache_dir,
+ "force_download": force_download,
+ "proxies": proxies,
+ "resume_download": resume_download,
+ "local_files_only": local_files_only,
+ "token": token,
+ "user_agent": user_agent,
+ "revision": revision,
+ "subfolder": subfolder,
+ "_raise_exceptions_for_gated_repo": False,
+ "_raise_exceptions_for_missing_entries": False,
+ "_commit_hash": commit_hash,
+ }
+ resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
+
+ # Maybe the checkpoint is sharded, we try to grab the index name in this case.
+ if resolved_archive_file is None and filename == FLAX_WEIGHTS_NAME:
+ resolved_archive_file = cached_file(
+ pretrained_model_name_or_path, FLAX_WEIGHTS_INDEX_NAME, **cached_file_kwargs
+ )
+ if resolved_archive_file is not None:
+ is_sharded = True
+
+ # Maybe the checkpoint is pytorch sharded, we try to grab the pytorch index name in this case.
+ if resolved_archive_file is None and from_pt:
+ resolved_archive_file = cached_file(
+ pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **cached_file_kwargs
+ )
+ if resolved_archive_file is not None:
+ is_sharded = True
+
+ # If we still haven't found anything, look for `safetensors`.
+ if resolved_archive_file is None:
+ # No support for sharded safetensors yet, so we'll raise an error if that's all we find.
+ filename = SAFE_WEIGHTS_NAME
+ resolved_archive_file = cached_file(
+ pretrained_model_name_or_path, SAFE_WEIGHTS_NAME, **cached_file_kwargs
+ )
+
+ # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
+ # result when internet is up, the repo and revision exist, but the file does not.
+ if resolved_archive_file is None:
+ # Otherwise, maybe there is a TF or Torch model file. We try those to give a helpful error
+ # message.
+ has_file_kwargs = {
+ "revision": revision,
+ "proxies": proxies,
+ "token": token,
+ "cache_dir": cache_dir,
+ "local_files_only": local_files_only,
+ }
+ if has_file(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME, **has_file_kwargs):
+ is_sharded = True
+ raise NotImplementedError(
+ "Support for sharded checkpoints using safetensors is coming soon!"
+ )
+ elif has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
+ raise OSError(
+ f"{pretrained_model_name_or_path} does not appear to have a file named"
+ f" {FLAX_WEIGHTS_NAME} but there is a file for PyTorch weights. Use `from_pt=True` to"
+ " load this model from those weights."
+ )
+ elif has_file(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **has_file_kwargs):
+ raise OSError(
+ f"{pretrained_model_name_or_path} does not appear to have a file named"
+ f" {FLAX_WEIGHTS_INDEX_NAME} but there is a sharded file for PyTorch weights. Use"
+ " `from_pt=True` to load this model from those weights."
+ )
+ else:
+ raise OSError(
+ f"{pretrained_model_name_or_path} does not appear to have a file named"
+ f" {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
+ )
+ except OSError:
+ # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted
+ # to the original exception.
+ raise
+ except Exception:
+ # For any other exception, we throw a generic error.
+ raise OSError(
+ f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
+ " from 'https://huggingface.co/models', make sure you don't have a local directory with the"
+ f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
+ f" directory containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
+ )
+
+ if is_local:
+ logger.info(f"loading weights file {archive_file}")
+ resolved_archive_file = archive_file
+ filename = resolved_archive_file.split(os.path.sep)[-1]
+ else:
+ logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}")
+ else:
+ resolved_archive_file = None
+
+ # We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
+ if is_sharded:
+ # resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
+ resolved_archive_file, _ = get_checkpoint_shard_files(
+ pretrained_model_name_or_path,
+ resolved_archive_file,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ token=token,
+ user_agent=user_agent,
+ revision=revision,
+ subfolder=subfolder,
+ _commit_hash=commit_hash,
+ )
+
+ safetensors_from_pt = False
+ if filename == SAFE_WEIGHTS_NAME:
+ with safe_open(resolved_archive_file, framework="flax") as f:
+ safetensors_metadata = f.metadata()
+ if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax"]:
+ raise OSError(
+ f"The safetensors archive passed at {resolved_archive_file} does not contain the valid metadata."
+ " Make sure you save your model with the `save_pretrained` method."
+ )
+ safetensors_from_pt = safetensors_metadata.get("format") == "pt"
+
+ # init random models
+ model = cls(config, *model_args, _do_init=_do_init, **model_kwargs)
+
+ if from_pt or safetensors_from_pt:
+ state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file, is_sharded)
+ else:
+ if is_sharded:
+ state = cls.load_flax_sharded_weights(resolved_archive_file)
+ else:
+ state = cls.load_flax_weights(resolved_archive_file)
+ # make sure all arrays are stored as jnp.arrays
+ # NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
+ # https://github.com/google/flax/issues/1261
+ if _do_init:
+ state = jax.tree_util.tree_map(jnp.array, state)
+ else:
+ # keep the params on CPU if we don't want to initialize
+ state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.local_devices(backend="cpu")[0]), state)
+
+ if "batch_stats" in state: # if flax model contains batch norm layers
+ # if model is base model only use model_prefix key
+ if (
+ cls.base_model_prefix not in dict(model.params_shape_tree["params"])
+ and cls.base_model_prefix in state["params"]
+ ):
+ state["params"] = state["params"][cls.base_model_prefix]
+ state["batch_stats"] = state["batch_stats"][cls.base_model_prefix]
+
+ # if model is head model and we are loading weights from base model
+ # we initialize new params dict with base_model_prefix
+ if (
+ cls.base_model_prefix in dict(model.params_shape_tree["params"])
+ and cls.base_model_prefix not in state["params"]
+ ):
+ state = {
+ "params": {cls.base_model_prefix: state["params"]},
+ "batch_stats": {cls.base_model_prefix: state["batch_stats"]},
+ }
+
+ else:
+ # if model is base model only use model_prefix key
+ if cls.base_model_prefix not in dict(model.params_shape_tree) and cls.base_model_prefix in state:
+ state = state[cls.base_model_prefix]
+
+ # if model is head model and we are loading weights from base model
+ # we initialize new params dict with base_model_prefix
+ if cls.base_model_prefix in dict(model.params_shape_tree) and cls.base_model_prefix not in state:
+ state = {cls.base_model_prefix: state}
+
+ # flatten dicts
+ state = flatten_dict(state)
+
+ random_state = flatten_dict(unfreeze(model.params if _do_init else model.params_shape_tree))
+
+ missing_keys = model.required_params - set(state.keys())
+ unexpected_keys = set(state.keys()) - model.required_params
+
+ # Disabling warning when porting pytorch weights to flax, flax does not uses num_batches_tracked
+ for unexpected_key in unexpected_keys.copy():
+ if "num_batches_tracked" in unexpected_key[-1]:
+ unexpected_keys.remove(unexpected_key)
+
+ if missing_keys and not _do_init:
+ logger.warning(
+ f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. "
+ "Make sure to call model.init_weights to initialize the missing weights."
+ )
+ cls._missing_keys = missing_keys
+
+ # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
+ # matching the weights in the model.
+ mismatched_keys = []
+ for key in state:
+ if key in random_state and state[key].shape != random_state[key].shape:
+ if ignore_mismatched_sizes:
+ mismatched_keys.append((key, state[key].shape, random_state[key].shape))
+ state[key] = random_state[key]
+ else:
+ raise ValueError(
+ f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
+ f"{state[key].shape} which is incompatible with the model shape {random_state[key].shape}. "
+ "Using `ignore_mismatched_sizes=True` if you really want to load this checkpoint inside this "
+ "model."
+ )
+
+ # add missing keys as random parameters if we are initializing
+ if missing_keys and _do_init:
+ for missing_key in missing_keys:
+ state[missing_key] = random_state[missing_key]
+
+ # remove unexpected keys to not be saved again
+ for unexpected_key in unexpected_keys:
+ del state[unexpected_key]
+
+ if len(unexpected_keys) > 0:
+ logger.warning(
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
+ f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
+ f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
+ " with another architecture (e.g. initializing a BertForSequenceClassification model from a"
+ " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
+ f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical"
+ " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
+ )
+ else:
+ logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
+
+ if len(missing_keys) > 0:
+ logger.warning(
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+ f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
+ " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
+ )
+ elif len(mismatched_keys) == 0:
+ logger.info(
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
+ f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
+ f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
+ " training."
+ )
+ if len(mismatched_keys) > 0:
+ mismatched_warning = "\n".join(
+ [
+ f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
+ for key, shape1, shape2 in mismatched_keys
+ ]
+ )
+ logger.warning(
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+ f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
+ f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
+ " to use it for predictions and inference."
+ )
+
+ # dictionary of key: dtypes for the model params
+ param_dtypes = jax.tree_util.tree_map(lambda x: x.dtype, state)
+ # extract keys of parameters not in jnp.float32
+ fp16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.float16]
+ bf16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.bfloat16]
+
+ # raise a warning if any of the parameters are not in jnp.float32
+ if len(fp16_params) > 0:
+ logger.warning(
+ f"Some of the weights of {model.__class__.__name__} were initialized in float16 precision from "
+ f"the model checkpoint at {pretrained_model_name_or_path}:\n{fp16_params}\n"
+ "You should probably UPCAST the model weights to float32 if this was not intended. "
+ "See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this."
+ )
+
+ if len(bf16_params) > 0:
+ logger.warning(
+ f"Some of the weights of {model.__class__.__name__} were initialized in bfloat16 precision from "
+ f"the model checkpoint at {pretrained_model_name_or_path}:\n{bf16_params}\n"
+ "You should probably UPCAST the model weights to float32 if this was not intended. "
+ "See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this."
+ )
+
+ # If it is a model with generation capabilities, attempt to load the generation config
+ if model.can_generate():
+ try:
+ model.generation_config = GenerationConfig.from_pretrained(
+ pretrained_model_name_or_path,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ _from_auto=from_auto_class,
+ _from_pipeline=from_pipeline,
+ **kwargs,
+ )
+ except OSError:
+ logger.info(
+ "Generation config file not found, using a generation config created from the model config."
+ )
+ pass
+
+ if _do_init:
+ # set correct parameters
+ model.params = unflatten_dict(state)
+ return model
+ else:
+ return model, unflatten_dict(state)
+
+ def save_pretrained(
+ self,
+ save_directory: Union[str, os.PathLike],
+ params=None,
+ push_to_hub=False,
+ max_shard_size="10GB",
+ token: Optional[Union[str, bool]] = None,
+ safe_serialization: bool = False,
+ **kwargs,
+ ):
+ """
+ Save a model and its configuration file to a directory, so that it can be re-loaded using the
+ `[`~FlaxPreTrainedModel.from_pretrained`]` class method
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to which to save. Will be created if it doesn't exist.
+ push_to_hub (`bool`, *optional*, defaults to `False`):
+ Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
+ repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
+ namespace).
+ max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
+ The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
+ lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`).
+
+
+
+ If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard
+ which will be bigger than `max_shard_size`.
+
+
+
+ token (`str` or `bool`, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
+ the token generated when running `hf auth login` (stored in `~/.huggingface`).
+ kwargs (`dict[str, Any]`, *optional*):
+ Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
+ safe_serialization (`bool`, *optional*, defaults to `False`):
+ Whether to save the model using `safetensors` or through msgpack.
+ """
+ use_auth_token = kwargs.pop("use_auth_token", None)
+
+ if use_auth_token is not None:
+ warnings.warn(
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+ FutureWarning,
+ )
+ if token is not None:
+ raise ValueError(
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+ )
+ token = use_auth_token
+
+ if token is not None:
+ kwargs["token"] = token
+
+ if os.path.isfile(save_directory):
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
+ return
+
+ os.makedirs(save_directory, exist_ok=True)
+
+ if push_to_hub:
+ commit_message = kwargs.pop("commit_message", None)
+ repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
+ repo_id = self._create_repo(repo_id, **kwargs)
+ files_timestamps = self._get_files_timestamps(save_directory)
+
+ # get abs dir
+ save_directory = os.path.abspath(save_directory)
+ # save config as well
+ self.config.architectures = [self.__class__.__name__[4:]]
+
+ # If we have a custom model, we copy the file defining it in the folder and set the attributes so it can be
+ # loaded from the Hub.
+ if self._auto_class is not None:
+ custom_object_save(self, save_directory, config=self.config)
+
+ self.config.save_pretrained(save_directory)
+ if self.can_generate():
+ self.generation_config.save_pretrained(save_directory)
+
+ # save model
+ weights_name = SAFE_WEIGHTS_NAME if safe_serialization else FLAX_WEIGHTS_NAME
+ output_model_file = os.path.join(save_directory, weights_name)
+
+ shards, index = flax_shard_checkpoint(params if params is not None else self.params, max_shard_size)
+ # Clean the folder from a previous save
+ for filename in os.listdir(save_directory):
+ full_filename = os.path.join(save_directory, filename)
+ weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
+ if filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) and filename not in shards:
+ os.remove(full_filename)
+
+ if index is None:
+ if safe_serialization:
+ params = params if params is not None else self.params
+ flat_dict = flatten_dict(params, sep=".")
+ safe_save_file(flat_dict, output_model_file, metadata={"format": "flax"})
+ else:
+ with open(output_model_file, "wb") as f:
+ params = params if params is not None else self.params
+ model_bytes = to_bytes(params)
+ f.write(model_bytes)
+
+ else:
+ save_index_file = os.path.join(save_directory, FLAX_WEIGHTS_INDEX_NAME)
+ # Save the index as well
+ with open(save_index_file, "w", encoding="utf-8") as f:
+ content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+ f.write(content)
+ logger.info(
+ f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
+ f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the "
+ f"index located at {save_index_file}."
+ )
+ for shard_file, shard in shards.items():
+ # the shard item are unflattened, to save them we need to flatten them again
+ with open(os.path.join(save_directory, shard_file), mode="wb") as f:
+ params = unflatten_dict(shard, sep="/")
+ shard_bytes = to_bytes(params)
+ f.write(shard_bytes)
+
+ logger.info(f"Model weights saved in {output_model_file}")
+
+ if push_to_hub:
+ self._upload_modified_files(
+ save_directory,
+ repo_id,
+ files_timestamps,
+ commit_message=commit_message,
+ token=token,
+ )
+
+ @classmethod
+ def register_for_auto_class(cls, auto_class="FlaxAutoModel"):
+ """
+ Register this class with a given auto class. This should only be used for custom models as the ones in the
+ library are already mapped with an auto class.
+
+
+
+ Args:
+ auto_class (`str` or `type`, *optional*, defaults to `"FlaxAutoModel"`):
+ The auto class to register this new model with.
+ """
+ if not isinstance(auto_class, str):
+ auto_class = auto_class.__name__
+
+ import transformers.models.auto as auto_module
+
+ if not hasattr(auto_module, auto_class):
+ raise ValueError(f"{auto_class} is not a valid auto class.")
+
+ cls._auto_class = auto_class
+
+
+# To update the docstring, we need to copy the method, otherwise we change the original docstring.
+FlaxPreTrainedModel.push_to_hub = copy_func(FlaxPreTrainedModel.push_to_hub)
+if FlaxPreTrainedModel.push_to_hub.__doc__ is not None:
+ FlaxPreTrainedModel.push_to_hub.__doc__ = FlaxPreTrainedModel.push_to_hub.__doc__.format(
+ object="model", object_class="FlaxAutoModel", object_files="model checkpoint"
+ )
+
+
+def overwrite_call_docstring(model_class, docstring):
+ # copy __call__ function to be sure docstring is changed only for this function
+ model_class.__call__ = copy_func(model_class.__call__)
+ # delete existing docstring
+ model_class.__call__.__doc__ = None
+ # set correct docstring
+ model_class.__call__ = add_start_docstrings_to_model_forward(docstring)(model_class.__call__)
+
+
+def append_call_sample_docstring(
+ model_class, checkpoint, output_type, config_class, mask=None, revision=None, real_checkpoint=None
+):
+ model_class.__call__ = copy_func(model_class.__call__)
+ model_class.__call__ = add_code_sample_docstrings(
+ checkpoint=checkpoint,
+ output_type=output_type,
+ config_class=config_class,
+ model_cls=model_class.__name__,
+ revision=revision,
+ real_checkpoint=real_checkpoint,
+ )(model_class.__call__)
+
+
+def append_replace_return_docstrings(model_class, output_type, config_class):
+ model_class.__call__ = copy_func(model_class.__call__)
+ model_class.__call__ = replace_return_docstrings(
+ output_type=output_type,
+ config_class=config_class,
+ )(model_class.__call__)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_gguf_pytorch_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_gguf_pytorch_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..08aaac3617ff65b5f9e8306bb04ea454e20be5ee
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_gguf_pytorch_utils.py
@@ -0,0 +1,532 @@
+# Copyright 2024 The ggml.ai team and The HuggingFace Inc. team. and pygguf author (github.com/99991)
+# https://github.com/99991/pygguf
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+from typing import NamedTuple, Optional
+
+import numpy as np
+from tqdm.auto import tqdm
+
+from .integrations import (
+ GGUF_CONFIG_MAPPING,
+ GGUF_TOKENIZER_MAPPING,
+ _gguf_parse_value,
+)
+from .utils import is_torch_available
+from .utils.import_utils import is_gguf_available
+from .utils.logging import get_logger
+
+
+if is_torch_available():
+ import torch
+
+logger = get_logger(__name__)
+
+
+GGUF_TO_TRANSFORMERS_MAPPING = {
+ "ignore": {
+ "GGUF": {
+ "version": "version",
+ "tensor_count": "tensor_count",
+ "kv_count": "kv_count",
+ },
+ "general": {"file_type": "file_type", "quantization_version": "quantization_version"},
+ },
+ "config": GGUF_CONFIG_MAPPING,
+ "tokenizer": {"tokenizer": GGUF_TOKENIZER_MAPPING["tokenizer"]},
+ "tokenizer_config": {"tokenizer": GGUF_TOKENIZER_MAPPING["tokenizer_config"]},
+}
+
+GGUF_SUPPORTED_ARCHITECTURES = list(GGUF_TO_TRANSFORMERS_MAPPING["config"].keys())
+
+
+class GGUFTensor(NamedTuple):
+ weights: np.ndarray
+ name: str
+ metadata: dict
+
+
+class TensorProcessor:
+ def __init__(self, config=None):
+ self.config = config or {}
+
+ def process(self, weights, name, **kwargs):
+ return GGUFTensor(weights, name, {})
+
+
+class LlamaTensorProcessor(TensorProcessor):
+ def __init__(self, config=None):
+ super().__init__(config=config)
+
+ def process(self, weights, name, **kwargs):
+ if ".attn_k." in name or ".attn_q." in name:
+ num_heads = self.config.get("num_attention_heads")
+ num_kv_heads = self.config.get("num_key_value_heads")
+
+ if None in (num_heads, num_kv_heads):
+ return GGUFTensor(weights, name, {})
+ if ".attn_q." in name:
+ weights = self._reverse_permute_weights(weights, num_heads, num_heads)
+ elif ".attn_k." in name:
+ weights = self._reverse_permute_weights(weights, num_heads, num_kv_heads)
+ return GGUFTensor(weights, name, {})
+
+ def _reverse_permute_weights(
+ self, weights: np.ndarray, n_head: int, num_kv_heads: Optional[int] = None
+ ) -> np.ndarray:
+ # Original permutation implementation
+ # https://github.com/ggerganov/llama.cpp/blob/a38b884c6c4b0c256583acfaaabdf556c62fabea/convert_hf_to_gguf.py#L1402-L1408
+ if num_kv_heads is not None and n_head != num_kv_heads:
+ n_head = num_kv_heads
+
+ dim = weights.shape[0] // n_head // 2
+ w = weights.reshape(n_head, dim, 2, *weights.shape[1:])
+ return w.swapaxes(2, 1).reshape(weights.shape)
+
+
+class Qwen2MoeTensorProcessor(TensorProcessor):
+ def __init__(self, config=None):
+ super().__init__(config=config)
+
+ def process(self, weights, name, **kwargs):
+ if "_exp" in name:
+ tensor_key_mapping = kwargs.get("tensor_key_mapping")
+ parsed_parameters = kwargs.get("parsed_parameters")
+ if tensor_key_mapping:
+ self._split_moe_expert_tensor(weights, parsed_parameters, name, tensor_key_mapping)
+ return GGUFTensor(weights, None, {})
+ if "ffn_gate_inp_shexp" in name:
+ # for compatibility tensor shared_expert_gate must be (1, 2048) dim,
+ # quantized one is (2048)
+ weights = np.expand_dims(weights, axis=0)
+ return GGUFTensor(weights, name, {})
+
+ def _split_moe_expert_tensor(
+ self, weights: np.ndarray, parsed_parameters: dict[str, dict], name: str, tensor_key_mapping: dict
+ ):
+ # Original merge implementation
+ # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L1994-L2022
+ name = tensor_key_mapping[name]
+ w_counter = self.config.get("num_experts", 60)
+ for i in range(0, w_counter):
+ temp_name = name.replace("mlp.experts.", f"mlp.experts.{i}.")
+ exp_weight = weights[i]
+ parsed_parameters["tensors"][temp_name] = torch.from_numpy(np.copy(exp_weight))
+
+
+class BloomTensorProcessor(TensorProcessor):
+ def __init__(self, config=None):
+ super().__init__(config=config)
+
+ def process(self, weights, name, **kwargs):
+ if "attn_qkv" in name:
+ num_heads = self.config["n_head"]
+ n_embed = self.config["hidden_size"]
+ if "weight" in name:
+ weights = self._reverse_reshape_weights(weights, num_heads, n_embed)
+ else:
+ weights = self._reverse_reshape_bias(weights, num_heads, n_embed)
+ return GGUFTensor(weights, name, {})
+
+ def _reverse_reshape_weights(self, weights: np.ndarray, n_head: int, n_embed: int):
+ # Original reshape implementation
+ # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L972-L985
+ q, k, v = np.array_split(weights, 3, axis=0)
+
+ q = q.reshape(n_head, n_embed // n_head, n_embed)
+ k = k.reshape(n_head, n_embed // n_head, n_embed)
+ v = v.reshape(n_head, n_embed // n_head, n_embed)
+ qkv_weights = np.stack([q, k, v], axis=1)
+
+ return qkv_weights.reshape(n_head * 3 * (n_embed // n_head), n_embed)
+
+ def _reverse_reshape_bias(self, weights: np.ndarray, n_head: int, n_embed: int):
+ # Original reshape implementation
+ # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L986-L998
+ q_bias, k_bias, v_bias = np.array_split(weights, 3)
+
+ q_bias = q_bias.reshape(n_head, n_embed // n_head)
+ k_bias = k_bias.reshape(n_head, n_embed // n_head)
+ v_bias = v_bias.reshape(n_head, n_embed // n_head)
+
+ qkv_bias = np.stack([q_bias, k_bias, v_bias], axis=1).flatten()
+ return qkv_bias
+
+
+class T5TensorProcessor(TensorProcessor):
+ def __init__(self, config=None):
+ super().__init__(config=config)
+
+ def process(self, weights, name, **kwargs):
+ bid = None
+ for chunk in name.split("."):
+ if chunk.isdigit():
+ bid = int(chunk)
+ break
+ return GGUFTensor(weights, name, {"bid": bid})
+
+
+class GPT2TensorProcessor(TensorProcessor):
+ def __init__(self, config=None):
+ super().__init__(config=config)
+
+ def process(self, weights, name, **kwargs):
+ # Original transpose implementation
+ # https://github.com/ggerganov/llama.cpp/blob/a38b884c6c4b0c256583acfaaabdf556c62fabea/convert_hf_to_gguf.py#L2060-L2061
+ if (
+ "attn_qkv.weight" in name
+ or "ffn_down.weight" in name
+ or "ffn_up.weight" in name
+ or "attn_output.weight" in name
+ ):
+ weights = weights.T
+
+ # Handle special case for output.weight
+ if name == "output.weight":
+ # output.weight has conflicts with attn_output.weight in name checking
+ # Store the tensor directly and signal to skip further processing
+ name = "lm_head.weight"
+ parsed_parameters = kwargs.get("parsed_parameters", {})
+ parsed_parameters["tensors"][name] = torch.from_numpy(np.copy(weights))
+ name = None # Signal to skip further processing
+ return GGUFTensor(weights, name, {})
+
+
+class MambaTensorProcessor(TensorProcessor):
+ def __init__(self, config=None):
+ super().__init__(config=config)
+
+ def process(self, weights, name, **kwargs):
+ if "ssm_conv1d.weight" in name:
+ # for compatibility tensor ssm_conv1d must be (5120, 1, 4]) dim,
+ # quantized one is (5120, 4)
+ weights = np.expand_dims(weights, axis=1)
+ if "ssm_a" in name:
+ # Original exponential implementation
+ # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L2975-L2977
+ weights = np.log(-weights)
+ return GGUFTensor(weights, name, {})
+
+
+class NemotronTensorProcessor(TensorProcessor):
+ def __init__(self, config=None):
+ super().__init__(config=config)
+
+ # ref : https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L4666
+ def process(self, weights, name, **kwargs):
+ if "norm.weight" in name:
+ weights = weights - 1
+ return GGUFTensor(weights, name, {})
+
+
+class Gemma2TensorProcessor(TensorProcessor):
+ def __init__(self, config=None):
+ super().__init__(config=config)
+
+ # ref: https://github.com/ggerganov/llama.cpp/blob/d79d8f39b4da6deca4aea8bf130c6034c482b320/convert_hf_to_gguf.py#L3191
+ # ref: https://github.com/huggingface/transformers/blob/fc37f38915372c15992b540dfcbbe00a916d4fc6/src/transformers/models/gemma/modeling_gemma.py#L89
+ def process(self, weights, name, **kwargs):
+ if "norm.weight" in name:
+ weights = weights - 1
+ return GGUFTensor(weights, name, {})
+
+
+class Lfm2TensorProcessor(TensorProcessor):
+ def __init__(self, config=None):
+ super().__init__(config=config)
+
+ def process(self, weights, name, **kwargs):
+ if "shortconv.conv.weight" in name:
+ ## GGUF shape is [hidden_dim, L_cache], HF expects [hidden_dim, 1, L_cache]
+ weights = np.expand_dims(weights, axis=1) ## equivalent to unsqueeze(1)
+ return GGUFTensor(weights, name, {})
+
+
+TENSOR_PROCESSORS = {
+ "llama": LlamaTensorProcessor,
+ "qwen2moe": Qwen2MoeTensorProcessor,
+ "qwen3moe": Qwen2MoeTensorProcessor,
+ "bloom": BloomTensorProcessor,
+ "t5": T5TensorProcessor,
+ "t5encoder": T5TensorProcessor,
+ "gpt2": GPT2TensorProcessor,
+ "mamba": MambaTensorProcessor,
+ "nemotron": NemotronTensorProcessor,
+ "gemma2": Gemma2TensorProcessor,
+ "gemma3": Gemma2TensorProcessor,
+ "lfm2": Lfm2TensorProcessor,
+}
+
+
+def read_field(reader, field):
+ if field not in reader.fields:
+ return []
+ value = reader.fields[field]
+ return [_gguf_parse_value(value.parts[_data_index], value.types) for _data_index in value.data]
+
+
+# modified from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/loader.py#L1115-L1147
+def get_gguf_hf_weights_map(
+ hf_model,
+ model_type: Optional[str] = None,
+ num_layers: Optional[int] = None,
+ qual_name: str = "",
+):
+ """
+ GGUF uses this naming convention for their tensors from HF checkpoint:
+ `blk.N.BB.weight` and `blk.N.BB.bias`
+ where N signifies the block number of a layer, and BB signifies the
+ attention/mlp layer components.
+ See "Standardized tensor names" in
+ https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details.
+ """
+ if is_gguf_available() and is_torch_available():
+ from gguf import MODEL_ARCH_NAMES, get_tensor_name_map
+ else:
+ logger.error(
+ "Loading a GGUF checkpoint in PyTorch, requires both PyTorch and GGUF>=0.10.0 to be installed. Please see "
+ "https://pytorch.org/ and https://github.com/ggerganov/llama.cpp/tree/master/gguf-py for installation instructions."
+ )
+ raise ImportError("Please install torch and gguf>=0.10.0 to load a GGUF checkpoint in PyTorch.")
+
+ model_type = hf_model.config.model_type if model_type is None else model_type
+ num_layers = hf_model.config.num_hidden_layers if num_layers is None else num_layers
+ # hack: ggufs have a different name for cohere
+ if model_type == "cohere":
+ model_type = "command-r"
+ elif model_type == "qwen2_moe":
+ model_type = "qwen2moe"
+ elif model_type == "qwen3_moe":
+ model_type = "qwen3moe"
+ elif model_type == "gemma3_text":
+ model_type = "gemma3"
+ elif model_type == "umt5":
+ model_type = "t5"
+ arch = None
+ for key, value in MODEL_ARCH_NAMES.items():
+ if value == model_type:
+ arch = key
+ break
+ if arch is None:
+ raise NotImplementedError(
+ f"Unknown gguf model_type: {model_type} in gguf-py. "
+ "This might because you're using an outdated version of gguf-py package, "
+ "you can install `gguf` package from source refer to "
+ "https://github.com/ggerganov/llama.cpp/tree/master/gguf-py#development"
+ )
+ name_map = get_tensor_name_map(arch, num_layers)
+
+ # Use a dummy conversion to get the mapping, because
+ # hf => gguf and gguf => hf mappings are reversed
+ gguf_to_hf_name_map = {}
+ state_dict = hf_model.state_dict()
+ for hf_name in state_dict:
+ # An exception for qwen2moe/qwen3moe model, where the expert layers are packed
+ if model_type in ("qwen2moe", "qwen3moe") and "mlp.experts." in hf_name:
+ hf_name = re.sub(r"mlp.experts.\d+.", "mlp.experts.", hf_name)
+
+ name, suffix = hf_name, ""
+ if hf_name.endswith(".weight") or hf_name.endswith(".bias"):
+ name, suffix = hf_name.rsplit(".", 1)
+ suffix = "." + suffix
+
+ gguf_name = name_map.get_name(name)
+ if gguf_name is None:
+ continue
+
+ gguf_to_hf_name_map[gguf_name + suffix] = qual_name + hf_name
+
+ # Some model like Bloom converted from BloomModel instead of BloomForCausalLM
+ # Therefore, we need to check submodule as well to get a correct mapping
+ if named_children := hf_model.named_children():
+ for name, child in named_children:
+ sub_map = get_gguf_hf_weights_map(child, model_type, num_layers, qual_name=f"{qual_name}{name}.")
+ # Ignore the keys that are already in the main map to avoid overwriting
+ sub_map = {k: v for k, v in sub_map.items() if k not in gguf_to_hf_name_map}
+ gguf_to_hf_name_map.update(sub_map)
+
+ return gguf_to_hf_name_map
+
+
+def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False, model_to_load=None):
+ """
+ Load a GGUF file and return a dictionary of parsed parameters containing tensors, the parsed
+ tokenizer and config attributes.
+
+ Args:
+ gguf_checkpoint_path (`str`):
+ The path the to GGUF file to load
+ return_tensors (`bool`, defaults to `False`):
+ Whether to read the tensors from the file and return them. Not doing so is faster
+ and only loads the metadata in memory.
+ """
+ if is_gguf_available() and is_torch_available():
+ from gguf import GGUFReader, dequantize
+ else:
+ logger.error(
+ "Loading a GGUF checkpoint in PyTorch, requires both PyTorch and GGUF>=0.10.0 to be installed. Please see "
+ "https://pytorch.org/ and https://github.com/ggerganov/llama.cpp/tree/master/gguf-py for installation instructions."
+ )
+ raise ImportError("Please install torch and gguf>=0.10.0 to load a GGUF checkpoint in PyTorch.")
+
+ reader = GGUFReader(gguf_checkpoint_path)
+ fields = reader.fields
+ reader_keys = list(fields.keys())
+
+ parsed_parameters = {k: {} for k in GGUF_TO_TRANSFORMERS_MAPPING}
+
+ architecture = read_field(reader, "general.architecture")[0]
+ # NOTE: Some GGUF checkpoints may miss `general.name` field in metadata
+ model_name = read_field(reader, "general.name")
+
+ updated_architecture = None
+ # in llama.cpp mistral models use the same architecture as llama. We need
+ # to add this patch to ensure things work correctly on our side.
+ if "llama" in architecture and "mistral" in model_name:
+ updated_architecture = "mistral"
+ # FIXME: Currently this implementation is only for flan-t5 architecture.
+ # It needs to be developed for supporting legacy t5.
+ elif "t5" in architecture or "t5encoder" in architecture:
+ parsed_parameters["config"]["is_gated_act"] = True
+ if model_name and "umt5" in model_name[0].lower():
+ updated_architecture = "umt5"
+ if "t5encoder" in architecture:
+ parsed_parameters["config"]["architectures"] = ["UMT5EncoderModel"]
+ else:
+ if "t5encoder" in architecture:
+ parsed_parameters["config"]["architectures"] = ["T5EncoderModel"]
+ updated_architecture = "t5"
+ else:
+ updated_architecture = architecture
+
+ if "qwen2moe" in architecture:
+ updated_architecture = "qwen2_moe"
+ elif "qwen3moe" in architecture:
+ updated_architecture = "qwen3_moe"
+
+ # For stablelm architecture, we need to set qkv_bias and use_parallel_residual from tensors
+ # If `qkv_bias=True`, qkv_proj with bias will be present in the tensors
+ # If `use_parallel_residual=False`, ffn_norm will be present in the tensors
+ if "stablelm" in architecture:
+ attn_bias_name = {"attn_q.bias", "attn_k.bias", "attn_v.bias"}
+ ffn_norm_name = "ffn_norm"
+ qkv_bias = any(bias_name in tensor.name for tensor in reader.tensors for bias_name in attn_bias_name)
+ use_parallel_residual = any(ffn_norm_name in tensor.name for tensor in reader.tensors)
+ parsed_parameters["config"]["use_qkv_bias"] = qkv_bias
+ parsed_parameters["config"]["use_parallel_residual"] = not use_parallel_residual
+
+ if architecture not in GGUF_SUPPORTED_ARCHITECTURES and updated_architecture not in GGUF_SUPPORTED_ARCHITECTURES:
+ raise ValueError(f"GGUF model with architecture {architecture} is not supported yet.")
+
+ # Handle tie_word_embeddings, if lm_head.weight is not present in tensors,
+ # tie_word_embeddings is true otherwise false
+ exceptions = ["falcon", "bloom"]
+ parsed_parameters["config"]["tie_word_embeddings"] = (
+ all("output.weight" != tensor.name for tensor in reader.tensors) or architecture in exceptions
+ )
+
+ # List all key-value pairs in a columnized format
+ for gguf_key, field in reader.fields.items():
+ gguf_key = gguf_key.replace(architecture, updated_architecture)
+ split = gguf_key.split(".")
+ prefix = split[0]
+ config_key = ".".join(split[1:])
+
+ value = [_gguf_parse_value(field.parts[_data_index], field.types) for _data_index in field.data]
+
+ if len(value) == 1:
+ value = value[0]
+
+ if isinstance(value, str) and architecture in value:
+ value = value.replace(architecture, updated_architecture)
+
+ for parameter, parameter_renames in GGUF_TO_TRANSFORMERS_MAPPING.items():
+ if prefix in parameter_renames and config_key in parameter_renames[prefix]:
+ renamed_config_key = parameter_renames[prefix][config_key]
+ if renamed_config_key == -1:
+ continue
+
+ if renamed_config_key is not None:
+ parsed_parameters[parameter][renamed_config_key] = value
+
+ if gguf_key in reader_keys:
+ reader_keys.remove(gguf_key)
+
+ if gguf_key in reader_keys:
+ logger.info(f"Some keys were not parsed and added into account {gguf_key} | {value}")
+
+ # Gemma3 GGUF checkpoint only contains weights of text backbone
+ if parsed_parameters["config"]["model_type"] == "gemma3":
+ parsed_parameters["config"]["model_type"] = "gemma3_text"
+
+ if parsed_parameters["config"]["model_type"] == "lfm2":
+ gguf_num_key_value_heads = parsed_parameters["config"]["num_key_value_heads"]
+ # LFM2 GGUF checkpoint defines num_key_value_heads as a list of integers .e.g [0, 0, 8, 0, 0, 8, 0, 0, 8, 0, 8, 0, 8, 0, 8, 0] but we need to set it to the max value for HF
+ parsed_parameters["config"]["num_key_value_heads"] = max(gguf_num_key_value_heads)
+ ## we already read the correct intermediate_size from the GGUF checkpoint so we need to set block_auto_adjust_ff_dim to False
+ parsed_parameters["config"]["block_auto_adjust_ff_dim"] = False
+
+ ## llama.cpp defines the layers that are full-attention by looking at num_key_value_heads
+ ## we need to set the full_attn_idxs to the layers that are full-attention
+ parsed_parameters["config"]["full_attn_idxs"] = [
+ i for i, num_kv_heads in enumerate(gguf_num_key_value_heads) if num_kv_heads > 0
+ ]
+
+ # retrieve config vocab_size from tokenizer
+ # Please refer to https://github.com/huggingface/transformers/issues/32526 for more details
+ if "vocab_size" not in parsed_parameters["config"]:
+ tokenizer_parameters = parsed_parameters["tokenizer"]
+ if "tokens" in tokenizer_parameters:
+ parsed_parameters["config"]["vocab_size"] = len(tokenizer_parameters["tokens"])
+ else:
+ logger.warning(
+ "Can't find a way to retrieve missing config vocab_size from tokenizer parameters. "
+ "This will use default value from model config class and cause unexpected behavior."
+ )
+
+ if return_tensors:
+ parsed_parameters["tensors"] = {}
+
+ tensor_key_mapping = get_gguf_hf_weights_map(model_to_load)
+ config = parsed_parameters.get("config", {})
+
+ ProcessorClass = TENSOR_PROCESSORS.get(architecture, TensorProcessor)
+ processor = ProcessorClass(config=config)
+
+ for tensor in tqdm(reader.tensors, desc="Converting and de-quantizing GGUF tensors..."):
+ name = tensor.name
+ weights = dequantize(tensor.data, tensor.tensor_type)
+
+ result = processor.process(
+ weights=weights,
+ name=name,
+ tensor_key_mapping=tensor_key_mapping,
+ parsed_parameters=parsed_parameters,
+ )
+
+ weights = result.weights
+ name = result.name
+
+ if name not in tensor_key_mapping:
+ continue
+
+ name = tensor_key_mapping[name]
+
+ parsed_parameters["tensors"][name] = torch.from_numpy(np.copy(weights))
+
+ if len(reader_keys) > 0:
+ logger.info(f"Some keys of the GGUF file were not considered: {reader_keys}")
+
+ return parsed_parameters
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_outputs.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_outputs.py
new file mode 100644
index 0000000000000000000000000000000000000000..1747f6fa477b56bc349f2c535705382ebc84af3c
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_outputs.py
@@ -0,0 +1,1715 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import warnings
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+
+from .cache_utils import Cache, EncoderDecoderCache
+from .utils import ModelOutput
+
+
+@dataclass
+class BaseModelOutput(ModelOutput):
+ """
+ Base class for model's outputs, with potential hidden states and attentions.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class BaseModelOutputWithNoAttention(ModelOutput):
+ """
+ Base class for model's outputs, with potential hidden states.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, num_channels, height, width)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class BaseModelOutputWithPooling(ModelOutput):
+ """
+ Base class for model's outputs that also contains a pooling of the last hidden states.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
+ Last layer hidden-state of the first token of the sequence (classification token) after further processing
+ through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns
+ the classification token after processing through a linear layer and a tanh activation function. The linear
+ layer weights are trained from the next sentence prediction (classification) objective during pretraining.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ pooler_output: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class BaseModelOutputWithPoolingAndNoAttention(ModelOutput):
+ """
+ Base class for model's outputs that also contains a pooling of the last hidden states.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
+ Last layer hidden-state after a pooling operation on the spatial dimensions.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, num_channels, height, width)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ pooler_output: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class BaseModelOutputWithPast(ModelOutput):
+ """
+ Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+
+ If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
+ hidden_size)` is output.
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
+ `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
+ input) to speed up sequential decoding.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class BaseModelOutputWithCrossAttentions(ModelOutput):
+ """
+ Base class for model's outputs, with potential hidden states and attentions.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+ weighted average in the cross-attention heads.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ cross_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class BaseModelOutputWithPoolingAndCrossAttentions(ModelOutput):
+ """
+ Base class for model's outputs that also contains a pooling of the last hidden states.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
+ Last layer hidden-state of the first token of the sequence (classification token) after further processing
+ through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns
+ the classification token after processing through a linear layer and a tanh activation function. The linear
+ layer weights are trained from the next sentence prediction (classification) objective during pretraining.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+ weighted average in the cross-attention heads.
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
+ `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
+ input) to speed up sequential decoding.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ pooler_output: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ past_key_values: Optional[Cache] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ cross_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class BaseModelOutputWithPastAndCrossAttentions(ModelOutput):
+ """
+ Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+
+ If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
+ hidden_size)` is output.
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
+ `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
+ input) to speed up sequential decoding.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+ weighted average in the cross-attention heads.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ cross_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class MoECausalLMOutputWithPast(ModelOutput):
+ """
+ Base class for causal language model (or autoregressive) outputs as well as Mixture of Expert's router hidden
+ states terms, to train a MoE model.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ z_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
+ z_loss for the sparse modules.
+ aux_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
+ aux_loss for the sparse modules.
+ router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
+
+ Router logits of the encoder model, useful to compute the auxiliary loss and the z_loss for the sparse
+ modules.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ z_loss: Optional[torch.FloatTensor] = None
+ aux_loss: Optional[torch.FloatTensor] = None
+ router_logits: Optional[tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+class MoEModelOutput(ModelOutput):
+ """
+ Base class for model's outputs, with potential hidden states and attentions.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ router_probs (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
+
+ Raw router probabilities that are computed by MoE routers, these terms are used to compute the auxiliary
+ loss and the z_loss for Mixture of Experts models.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ router_probs: Optional[tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+class MoeModelOutputWithPast(ModelOutput):
+ """
+ Base class for model's outputs, with potential hidden states and attentions.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
+ `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
+ input) to speed up sequential decoding.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
+
+ Raw router logtis (post-softmax) that are computed by MoE routers, these terms are used to compute the auxiliary
+ loss for Mixture of Experts models.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ router_logits: Optional[tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+class MoeCausalLMOutputWithPast(ModelOutput):
+ """
+ Base class for causal language model (or autoregressive) with mixture of experts outputs.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+
+ aux_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
+ aux_loss for the sparse modules.
+
+ router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
+
+ Raw router logtis (post-softmax) that are computed by MoE routers, these terms are used to compute the auxiliary
+ loss for Mixture of Experts models.
+
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ aux_loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ router_logits: Optional[tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+class MoEModelOutputWithPastAndCrossAttentions(ModelOutput):
+ """
+ Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding) as well as
+ Mixture of Expert's router hidden states terms, to train a MoE model.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+
+ If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
+ hidden_size)` is output.
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
+ `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
+ input) to speed up sequential decoding.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+ weighted average in the cross-attention heads.
+ router_probs (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
+
+ Raw router probabilities that are computed by MoE routers, these terms are used to compute the auxiliary
+ loss and the z_loss for Mixture of Experts models.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ cross_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ router_probs: Optional[tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+class Seq2SeqModelOutput(ModelOutput):
+ """
+ Base class for model encoder's outputs that also contains : pre-computed hidden states that can speed up sequential
+ decoding.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the decoder of the model.
+
+ If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
+ hidden_size)` is output.
+ past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+ decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the decoder at the output of each layer plus the optional initial embedding outputs.
+ decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+ weighted average in the cross-attention heads.
+ encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder of the model.
+ encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the encoder at the output of each layer plus the optional initial embedding outputs.
+ encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[EncoderDecoderCache] = None
+ decoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ cross_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+ encoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ encoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class Seq2SeqMoEModelOutput(ModelOutput):
+ """
+ Base class for model encoder's outputs that also contains : pre-computed hidden states that can speed up sequential
+ decoding.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the decoder of the model.
+
+ If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
+ hidden_size)` is output.
+ past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+ decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the decoder at the output of each layer plus the optional initial embedding outputs.
+ decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ decoder_router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
+
+ Router logits of the decoder model, useful to compute the auxiliary loss for Mixture of Experts models.
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+ weighted average in the cross-attention heads.
+ encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder of the model.
+ encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the encoder at the output of each layer plus the optional initial embedding outputs.
+ encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ encoder_router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
+
+ Router logits of the encoder model, useful to compute the auxiliary loss and the z_loss for the sparse
+ modules.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[EncoderDecoderCache] = None
+ decoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ decoder_router_logits: Optional[tuple[torch.FloatTensor]] = None
+ cross_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+ encoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ encoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ encoder_router_logits: Optional[tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+class CausalLMOutput(ModelOutput):
+ """
+ Base class for causal language model (or autoregressive) outputs.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class CausalLMOutputWithPast(ModelOutput):
+ """
+ Base class for causal language model (or autoregressive) outputs.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class CausalLMOutputWithCrossAttentions(ModelOutput):
+ """
+ Base class for causal language model (or autoregressive) outputs.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Cross attentions weights after the attention softmax, used to compute the weighted average in the
+ cross-attention heads.
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ cross_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class SequenceClassifierOutputWithPast(ModelOutput):
+ """
+ Base class for outputs of sentence classification models.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Classification (or regression if config.num_labels==1) loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
+ Classification (or regression if config.num_labels==1) scores (before SoftMax).
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class MaskedLMOutput(ModelOutput):
+ """
+ Base class for masked language models outputs.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Masked language modeling (MLM) loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class Seq2SeqLMOutput(ModelOutput):
+ """
+ Base class for sequence-to-sequence language models outputs.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+ decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
+ decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+ weighted average in the cross-attention heads.
+ encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder of the model.
+ encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
+ encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[EncoderDecoderCache] = None
+ decoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ cross_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+ encoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ encoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class Seq2SeqMoEOutput(ModelOutput):
+ """
+ Base class for sequence-to-sequence language models outputs.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+ decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
+ decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ decoder_router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
+
+ Router logits of the decoder model, useful to compute the auxiliary loss for Mixture of Experts models.
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+ weighted average in the cross-attention heads.
+ encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder of the model.
+ encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
+ encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ encoder_router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
+
+ Router logits of the encoder model, useful to compute the auxiliary loss and z_loss for Mixture of Experts
+ models.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ encoder_z_loss: Optional[torch.FloatTensor] = None
+ decoder_z_loss: Optional[torch.FloatTensor] = None
+ encoder_aux_loss: Optional[torch.FloatTensor] = None
+ decoder_aux_loss: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[EncoderDecoderCache] = None
+ decoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ decoder_router_logits: Optional[tuple[torch.FloatTensor]] = None
+ cross_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+ encoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ encoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ encoder_router_logits: Optional[tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+class NextSentencePredictorOutput(ModelOutput):
+ """
+ Base class for outputs of models predicting if two sentences are consecutive or not.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `next_sentence_label` is provided):
+ Next sequence prediction (classification) loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
+ Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
+ before SoftMax).
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class SequenceClassifierOutput(ModelOutput):
+ """
+ Base class for outputs of sentence classification models.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Classification (or regression if config.num_labels==1) loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
+ Classification (or regression if config.num_labels==1) scores (before SoftMax).
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class Seq2SeqSequenceClassifierOutput(ModelOutput):
+ """
+ Base class for outputs of sequence-to-sequence sentence classification models.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `label` is provided):
+ Classification (or regression if config.num_labels==1) loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
+ Classification (or regression if config.num_labels==1) scores (before SoftMax).
+ past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+ decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
+ decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+ weighted average in the cross-attention heads.
+ encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder of the model.
+ encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
+ encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[EncoderDecoderCache] = None
+ decoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ cross_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+ encoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ encoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class MultipleChoiceModelOutput(ModelOutput):
+ """
+ Base class for outputs of multiple choice models.
+
+ Args:
+ loss (`torch.FloatTensor` of shape *(1,)*, *optional*, returned when `labels` is provided):
+ Classification loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`):
+ *num_choices* is the second dimension of the input tensors. (see *input_ids* above).
+
+ Classification scores (before SoftMax).
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class TokenClassifierOutput(ModelOutput):
+ """
+ Base class for outputs of token classification models.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided) :
+ Classification loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`):
+ Classification scores (before SoftMax).
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class QuestionAnsweringModelOutput(ModelOutput):
+ """
+ Base class for outputs of question answering models.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
+ start_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+ Span-start scores (before SoftMax).
+ end_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+ Span-end scores (before SoftMax).
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ start_logits: Optional[torch.FloatTensor] = None
+ end_logits: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class Seq2SeqQuestionAnsweringModelOutput(ModelOutput):
+ """
+ Base class for outputs of sequence-to-sequence question answering models.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
+ start_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+ Span-start scores (before SoftMax).
+ end_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+ Span-end scores (before SoftMax).
+ past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+ decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
+ decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+ weighted average in the cross-attention heads.
+ encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder of the model.
+ encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
+ encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ start_logits: Optional[torch.FloatTensor] = None
+ end_logits: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[EncoderDecoderCache] = None
+ decoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ cross_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+ encoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ encoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class SemanticSegmenterOutput(ModelOutput):
+ """
+ Base class for outputs of semantic segmentation models.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Classification (or regression if config.num_labels==1) loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels, logits_height, logits_width)`):
+ Classification scores for each pixel.
+
+
+
+ The logits returned do not necessarily have the same size as the `pixel_values` passed as inputs. This is
+ to avoid doing two interpolations and lose some quality when a user needs to resize the logits to the
+ original image size as post-processing. You should always check your logits shape and resize as needed.
+
+
+
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, patch_size, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class ImageClassifierOutput(ModelOutput):
+ """
+ Base class for outputs of image classification models.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Classification (or regression if config.num_labels==1) loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
+ Classification (or regression if config.num_labels==1) scores (before SoftMax).
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states
+ (also called feature maps) of the model at the output of each stage.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class ImageClassifierOutputWithNoAttention(ModelOutput):
+ """
+ Base class for outputs of image classification models.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Classification (or regression if config.num_labels==1) loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
+ Classification (or regression if config.num_labels==1) scores (before SoftMax).
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each stage) of shape `(batch_size, num_channels, height, width)`. Hidden-states (also
+ called feature maps) of the model at the output of each stage.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class DepthEstimatorOutput(ModelOutput):
+ """
+ Base class for outputs of depth estimation models.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Classification (or regression if config.num_labels==1) loss.
+ predicted_depth (`torch.FloatTensor` of shape `(batch_size, height, width)`):
+ Predicted depth for each pixel.
+
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, num_channels, height, width)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ predicted_depth: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class ImageSuperResolutionOutput(ModelOutput):
+ """
+ Base class for outputs of image super resolution models.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Reconstruction loss.
+ reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Reconstructed images, possibly upscaled.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states
+ (also called feature maps) of the model at the output of each stage.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ reconstruction: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class Wav2Vec2BaseModelOutput(ModelOutput):
+ """
+ Base class for models that have been trained with the Wav2Vec2 loss objective.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ extract_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, conv_dim[-1])`):
+ Sequence of extracted feature vectors of the last convolutional layer of the model.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+ shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ extract_features: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class XVectorOutput(ModelOutput):
+ """
+ Output type of [`Wav2Vec2ForXVector`].
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Classification loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, config.xvector_output_dim)`):
+ Classification hidden states before AMSoftmax.
+ embeddings (`torch.FloatTensor` of shape `(batch_size, config.xvector_output_dim)`):
+ Utterance embeddings used for vector similarity-based retrieval.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+ shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ embeddings: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class BackboneOutput(ModelOutput):
+ """
+ Base class for outputs of backbones.
+
+ Args:
+ feature_maps (`tuple(torch.FloatTensor)` of shape `(batch_size, num_channels, height, width)`):
+ Feature maps of the stages.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+ shape `(batch_size, sequence_length, hidden_size)` or `(batch_size, num_channels, height, width)`,
+ depending on the backbone.
+
+ Hidden-states of the model at the output of each stage plus the initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`. Only applicable if the backbone uses attention.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ feature_maps: Optional[tuple[torch.FloatTensor]] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class BaseModelOutputWithPoolingAndProjection(ModelOutput):
+ """
+ Base class for model's outputs that also contains a pooling of the last hidden states.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
+ Last layer hidden-state of the first token of the sequence (classification token) after further processing
+ through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns
+ the classification token after processing through a linear layer and a tanh activation function. The linear
+ layer weights are trained from the next sentence prediction (classification) objective during pretraining.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ projection_state (`tuple(torch.FloatTensor)`, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` of shape `(batch_size,config.project_dim)`.
+
+ Text embeddings before the projection layer, used to mimic the last hidden state of the teacher encoder.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ pooler_output: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ projection_state: Optional[tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+class Seq2SeqSpectrogramOutput(ModelOutput):
+ """
+ Base class for sequence-to-sequence spectrogram outputs.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Spectrogram generation loss.
+ spectrogram (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_bins)`):
+ The predicted spectrogram.
+ past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+ decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
+ decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+ weighted average in the cross-attention heads.
+ encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder of the model.
+ encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
+ encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ spectrogram: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[EncoderDecoderCache] = None
+ decoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ cross_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+ encoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ encoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class Seq2SeqTSModelOutput(ModelOutput):
+ """
+ Base class for time series model's encoder outputs that also contains pre-computed hidden states that can speed up
+ sequential decoding.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the decoder of the model.
+
+ If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
+ hidden_size)` is output.
+ past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+ decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the decoder at the output of each layer plus the optional initial embedding outputs.
+ decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+ weighted average in the cross-attention heads.
+ encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder of the model.
+ encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the encoder at the output of each layer plus the optional initial embedding outputs.
+ encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ loc (`torch.FloatTensor` of shape `(batch_size,)` or `(batch_size, input_size)`, *optional*):
+ Shift values of each time series' context window which is used to give the model inputs of the same
+ magnitude and then used to shift back to the original magnitude.
+ scale (`torch.FloatTensor` of shape `(batch_size,)` or `(batch_size, input_size)`, *optional*):
+ Scaling values of each time series' context window which is used to give the model inputs of the same
+ magnitude and then used to rescale back to the original magnitude.
+ static_features (`torch.FloatTensor` of shape `(batch_size, feature size)`, *optional*):
+ Static features of each time series' in a batch which are copied to the covariates at inference time.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[EncoderDecoderCache] = None
+ decoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ cross_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+ encoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ encoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ loc: Optional[torch.FloatTensor] = None
+ scale: Optional[torch.FloatTensor] = None
+ static_features: Optional[torch.FloatTensor] = None
+
+
+@dataclass
+class Seq2SeqTSPredictionOutput(ModelOutput):
+ """
+ Base class for time series model's decoder outputs that also contain the loss as well as the parameters of the
+ chosen distribution.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when a `future_values` is provided):
+ Distributional loss.
+ params (`torch.FloatTensor` of shape `(batch_size, num_samples, num_params)`):
+ Parameters of the chosen distribution.
+ past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+ decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
+ decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+ weighted average in the cross-attention heads.
+ encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder of the model.
+ encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
+ encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ loc (`torch.FloatTensor` of shape `(batch_size,)` or `(batch_size, input_size)`, *optional*):
+ Shift values of each time series' context window which is used to give the model inputs of the same
+ magnitude and then used to shift back to the original magnitude.
+ scale (`torch.FloatTensor` of shape `(batch_size,)` or `(batch_size, input_size)`, *optional*):
+ Scaling values of each time series' context window which is used to give the model inputs of the same
+ magnitude and then used to rescale back to the original magnitude.
+ static_features (`torch.FloatTensor` of shape `(batch_size, feature size)`, *optional*):
+ Static features of each time series' in a batch which are copied to the covariates at inference time.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ params: Optional[tuple[torch.FloatTensor, ...]] = None
+ past_key_values: Optional[EncoderDecoderCache] = None
+ decoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ cross_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+ encoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ encoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ loc: Optional[torch.FloatTensor] = None
+ scale: Optional[torch.FloatTensor] = None
+ static_features: Optional[torch.FloatTensor] = None
+
+
+@dataclass
+class SampleTSPredictionOutput(ModelOutput):
+ """
+ Base class for time series model's predictions outputs that contains the sampled values from the chosen
+ distribution.
+
+ Args:
+ sequences (`torch.FloatTensor` of shape `(batch_size, num_samples, prediction_length)` or `(batch_size, num_samples, prediction_length, input_size)`):
+ Sampled values from the chosen distribution.
+ """
+
+ sequences: Optional[torch.FloatTensor] = None
+
+
+@dataclass
+class MaskedImageModelingOutput(ModelOutput):
+ """
+ Base class for outputs of masked image completion / in-painting models.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
+ Reconstruction loss.
+ reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Reconstructed / completed images.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or
+ when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states
+ (also called feature maps) of the model at the output of each stage.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when
+ `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size,
+ sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
+ the self-attention heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ reconstruction: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+ @property
+ def logits(self):
+ warnings.warn(
+ "logits attribute is deprecated and will be removed in version 5 of Transformers."
+ " Please use the reconstruction attribute to retrieve the final output instead.",
+ FutureWarning,
+ )
+ return self.reconstruction
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_rope_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_rope_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0070df6ee17ae08c4ddcd6294379bf207867e03
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_rope_utils.py
@@ -0,0 +1,773 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from functools import wraps
+from typing import Optional
+
+from .configuration_utils import PretrainedConfig
+from .utils import is_torch_available, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+if is_torch_available():
+ import torch
+
+
+def dynamic_rope_update(rope_forward):
+ """
+ Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
+ (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
+
+ Args:
+ rope_forward (Callable):
+ The forward pass of the RoPE implementation.
+
+ Returns:
+ The decorated forward pass.
+ """
+
+ def longrope_frequency_update(self, position_ids, device):
+ """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
+ seq_len = torch.max(position_ids) + 1
+ if hasattr(self.config, "original_max_position_embeddings"):
+ original_max_position_embeddings = self.config.original_max_position_embeddings
+ else:
+ original_max_position_embeddings = self.config.max_position_embeddings
+ if seq_len > original_max_position_embeddings:
+ if not hasattr(self, "long_inv_freq"):
+ self.long_inv_freq, _ = self.rope_init_fn(
+ self.config, device, seq_len=original_max_position_embeddings + 1
+ )
+ self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
+ else:
+ # This .to() is needed if the model has been moved to a device after being initialized (because
+ # the buffer is automatically moved, but not the original copy)
+ self.original_inv_freq = self.original_inv_freq.to(device)
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+
+ def dynamic_frequency_update(self, position_ids, device):
+ """
+ dynamic RoPE layers should recompute `inv_freq` in the following situations:
+ 1 - growing beyond the cached sequence length (allow scaling)
+ 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
+ """
+ seq_len = torch.max(position_ids) + 1
+ if seq_len > self.max_seq_len_cached: # growth
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
+ self.max_seq_len_cached = seq_len
+
+ if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
+ # This .to() is needed if the model has been moved to a device after being initialized (because
+ # the buffer is automatically moved, but not the original copy)
+ self.original_inv_freq = self.original_inv_freq.to(device)
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+ self.max_seq_len_cached = self.original_max_seq_len
+
+ @wraps(rope_forward)
+ def wrapper(self, x, position_ids):
+ if "dynamic" in self.rope_type:
+ dynamic_frequency_update(self, position_ids, device=x.device)
+ elif self.rope_type == "longrope":
+ longrope_frequency_update(self, position_ids, device=x.device)
+ return rope_forward(self, x, position_ids)
+
+ return wrapper
+
+
+def _compute_default_rope_parameters(
+ config: Optional[PretrainedConfig] = None,
+ device: Optional["torch.device"] = None,
+ seq_len: Optional[int] = None,
+) -> tuple["torch.Tensor", float]:
+ """
+ Computes the inverse frequencies according to the original RoPE implementation
+ Args:
+ config ([`~transformers.PretrainedConfig`]):
+ The model configuration. This function assumes that the config will provide at least the following
+ properties:
+
+ * rope_theta (`float`): The base wavelength from which the inverse frequencies will be derived.
+ * hidden_size (`int`): The numerator when deriving a head_dim, if not provided directly.
+ * num_attention_heads (`int`): The denominator when deriving a head_dim, if not provided directly.
+
+ Additionally, this function will make use of the following properties if they are found in the config:
+
+ * head_dim (`int`, *optional*): The size of the key-value heads in the model. If None, this value will be
+ derived as hidden_size // num_attention_heads.
+ * partial_rotary_factor (`float`, *optional*): If less than 1.0, inverse frequencies will be returned for
+ the first fraction of the head_dim. Defaults to 1.0.
+ device (`torch.device`):
+ The device to use for initialization of the inverse frequencies.
+ seq_len (`int`, *optional*):
+ The current sequence length. Unused for this type of RoPE.
+
+ Returns:
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+ post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
+ """
+ base = config.rope_theta
+ partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
+ head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
+ dim = int(head_dim * partial_rotary_factor)
+
+ attention_factor = 1.0 # Unused in this type of RoPE
+
+ # Compute the inverse frequencies
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim))
+ return inv_freq, attention_factor
+
+
+def _compute_linear_scaling_rope_parameters(
+ config: Optional[PretrainedConfig] = None,
+ device: Optional["torch.device"] = None,
+ seq_len: Optional[int] = None,
+) -> tuple["torch.Tensor", float]:
+ """
+ Computes the inverse frequencies with linear scaling. Credits to the Reddit user /u/kaiokendev
+ Args:
+ config ([`~transformers.PretrainedConfig`]):
+ The model configuration. This function assumes that the config will provide at least the following
+ properties:
+
+ * rope_theta (`float`): The base wavelength from which the inverse frequencies will be derived.
+ * hidden_size (`int`): The numerator when deriving a head_dim, if not provided directly.
+ * num_attention_heads (`int`): The denominator when deriving a head_dim, if not provided directly.
+
+ Additionally, this function will make use of the following properties if they are found in the config:
+
+ * head_dim (`int`, *optional*): The size of the key-value heads in the model. If None, this value will be
+ derived as hidden_size // num_attention_heads.
+ * partial_rotary_factor (`float`, *optional*): If less than 1.0, inverse frequencies will be returned for
+ the first fraction of the head_dim. Defaults to 1.0.
+ device (`torch.device`):
+ The device to use for initialization of the inverse frequencies.
+ seq_len (`int`, *optional*):
+ The current sequence length. Unused for this type of RoPE.
+
+ Returns:
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+ post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
+ """
+ factor = config.rope_scaling["factor"]
+
+ # Gets the default RoPE parameters
+ inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len)
+
+ # Then applies linear scaling to the frequencies.
+ # NOTE: originally, scaling was applied to the position_ids. However, we get `embs = inv_freq @ position_ids`, so
+ # applying scaling to the inverse frequencies is equivalent.
+ inv_freq /= factor
+ return inv_freq, attention_factor
+
+
+def _compute_dynamic_ntk_parameters(
+ config: Optional[PretrainedConfig] = None,
+ device: Optional["torch.device"] = None,
+ seq_len: Optional[int] = None,
+) -> tuple["torch.Tensor", float]:
+ """
+ Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla
+
+ Args:
+ config ([`~transformers.PretrainedConfig`]):
+ The model configuration. This function assumes that the config will provide at least the following
+ properties:
+
+ * rope_theta (`float`): The base wavelength from which the inverse frequencies will be derived.
+ * hidden_size (`int`): The numerator when deriving a head_dim, if not provided directly.
+ * num_attention_heads (`int`): The denominator when deriving a head_dim, if not provided directly.
+ * max_position_embeddings (`int`): The default sequence length used to update the dynamic RoPE at
+ inference time
+ * rope_scaling (`dict[str, float]`): The standard RoPE scaling parameters, from which `factor`
+ will be accessed. The value of `factor` is used to determine the new base frequency, along with the
+ current sequence length (seq_len), the maximum positional embeddings (max_position_embeddings), and the
+ computed dimensionality (dim) of the rotary embeddings. If seq_len <= max_position_embeddings, this
+ factor has no effect. If seq_len <= max_position_embeddings, this factor effectively stretches the
+ context window using an exponent derived from `dim`.
+
+ Additionally, this function will make use of the following properties if they are found in the config:
+
+ * head_dim (`int`, *optional*): The size of the key-value heads in the model. If None, this value will be
+ derived as hidden_size // num_attention_heads.
+ * partial_rotary_factor (`float`, *optional*): If less than 1.0, inverse frequencies will be returned for
+ the first fraction of the head_dim. Defaults to 1.0.
+ device (`torch.device`):
+ The device to use for initialization of the inverse frequencies.
+ seq_len (`int`, *optional*):
+ The current sequence length, used to update the dynamic RoPE at inference time. If `None` or shorter than
+ max_position_embeddings, this value will be overridden by max_position_embeddings.
+
+ Returns:
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+ post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
+ """
+ # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
+ base = config.rope_theta
+ partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
+ head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ dim = int(head_dim * partial_rotary_factor)
+ max_position_embeddings = config.max_position_embeddings
+ factor = config.rope_scaling["factor"]
+
+ attention_factor = 1.0 # Unused in this type of RoPE
+
+ # seq_len: default to max_position_embeddings, e.g. at init time
+ if seq_len is None:
+ seq_len = max_position_embeddings
+ elif isinstance(seq_len, torch.Tensor):
+ seq_len = torch.maximum(
+ seq_len,
+ torch.tensor(max_position_embeddings, dtype=seq_len.dtype, device=seq_len.device),
+ )
+ else:
+ seq_len = max(seq_len, max_position_embeddings)
+
+ # Compute the inverse frequencies
+ base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim))
+ return inv_freq, attention_factor
+
+
+def _compute_yarn_parameters(
+ config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None
+) -> tuple["torch.Tensor", float]:
+ """
+ Computes the inverse frequencies with NTK scaling. Please refer to the
+ [original paper](https://huggingface.co/papers/2309.00071)
+
+ Args:
+ config ([`~transformers.PretrainedConfig`]):
+ The model configuration. This function assumes that the config will provide at least the following
+ properties:
+
+ * rope_theta (`float`): The base wavelength from which the inverse frequencies will be derived.
+ * hidden_size (`int`): The numerator when deriving a head_dim, if not provided directly.
+ * num_attention_heads (`int`): The denominator when deriving a head_dim, if not provided directly.
+ * max_position_embeddings (`int`): The maximum length of the positional embeddings.
+ * rope_scaling (`dict[str, float | int]`): The standard RoPE scaling parameters, from which the following
+ keys will be accessed:
+ * `attention_factor` (`float`, *optional*): The scaling factor to be applied to the computed cos/sin.
+ If None, the value is inferred from `factor`, `mscale`, and `mscale_all_dim` as avaialble.
+ * `beta_fast` (`float`, *optional*, defaults to 32): Parameter to set the boundary for extrapolation
+ (only) in the linear ramp function.
+ * `beta_slow` (`float`, *optional*, defaults to 1): Parameter to set the boundary for interpolation
+ (only) in the linear ramp function.
+ * `factor` (`float`, *optional*): The scaling factor applied when interpolating the position IDs to
+ extend the possible context length. Additionally, if `attention_factor` is None, the log of this
+ value is used to compute a value for `attention_factor`, possibly in conjunciton with `mscale` and
+ `mscale_all_dim`, if provided.
+ * `mscale` (`float`, *optional*): If `attention_factor` is None and both `mscale` and
+ `mscale_all_dim` are provided, `mscale` acts scalar augmenting `log(factor)` when computing the
+ numerator for the inferred value of `attention_factor`. If not provided, `attention_factor` will be
+ calculated based on `factor` only.
+ * `mscale_all_dim` (`float`, *optional*): If `attention_factor` is None and both `mscale` and
+ `mscale_all_dim` are provided, `mscale_all_dim` acts scalar augmenting `log(factor)` when computing
+ the denominator for the inferred value of `attention_factor`. If not provided, `attention_factor`
+ will be calculated based on `factor` only.
+ * `original_max_position_embeddings` (`int`, *optional*): The original max position embeddings used
+ during pretraining. If not provided, the function falls back to `max_position_embeddings`.
+ * `truncate` (`bool`, *optional*): Whether to truncate the correction range.
+
+ Additionally, this function will make use of the following properties if they are found in the config:
+
+ * head_dim (`int`, *optional*): The size of the key-value heads in the model. If None, this value will be
+ derived as hidden_size // num_attention_heads.
+ * partial_rotary_factor (`float`, *optional*, defaults to 1.0): If less than 1.0, inverse frequencies
+ will be returned for the first fraction of the head_dim.
+ device (`torch.device`):
+ The device to use for initialization of the inverse frequencies.
+ seq_len (`int`, *optional*):
+ The current sequence length. Unused for this type of RoPE.
+
+ Returns:
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+ post-processing scaling factor applied to the computed cos/sin.
+ """
+
+ base = config.rope_theta
+ partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
+ head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ dim = int(head_dim * partial_rotary_factor)
+ factor = config.rope_scaling["factor"]
+ attention_factor = config.rope_scaling.get("attention_factor")
+ mscale = config.rope_scaling.get("mscale")
+ mscale_all_dim = config.rope_scaling.get("mscale_all_dim")
+ original_max_position_embeddings = (
+ config.rope_scaling.get("original_max_position_embeddings") or config.max_position_embeddings
+ )
+
+ def get_mscale(scale, mscale=1):
+ if scale <= 1:
+ return 1.0
+ return 0.1 * mscale * math.log(scale) + 1.0
+
+ # Sets the attention factor as suggested in the paper
+ if attention_factor is None:
+ if mscale and mscale_all_dim:
+ attention_factor = float(get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim))
+ else:
+ attention_factor = get_mscale(factor)
+
+ # Optional config options
+ # beta_fast/beta_slow: as suggested in the paper, default to 32 and 1 respectively
+ beta_fast = config.rope_scaling.get("beta_fast") or 32
+ beta_slow = config.rope_scaling.get("beta_slow") or 1
+
+ # Compute the inverse frequencies
+ def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
+ """Inverse dimension formula to find the dimension based on the number of rotations"""
+ return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
+
+ def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings, truncate):
+ """Find dimension range bounds based on rotations"""
+ low = find_correction_dim(low_rot, dim, base, max_position_embeddings)
+ high = find_correction_dim(high_rot, dim, base, max_position_embeddings)
+ if truncate:
+ low = math.floor(low)
+ high = math.ceil(high)
+ return max(low, 0), min(high, dim - 1)
+
+ def linear_ramp_factor(min, max, dim):
+ if min == max:
+ max += 0.001 # Prevent singularity
+
+ linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
+ ramp_func = torch.clamp(linear_func, 0, 1)
+ return ramp_func
+
+ # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
+ # to expand the possible context length. In other words, interpolation = apply scaling factor.
+ pos_freqs = base ** (torch.arange(0, dim, 2).to(device=device, dtype=torch.float) / dim)
+ inv_freq_extrapolation = 1.0 / pos_freqs
+ inv_freq_interpolation = 1.0 / (factor * pos_freqs)
+
+ truncate = config.rope_scaling.get("truncate", True)
+ low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings, truncate)
+
+ # Get n-dimensional rotational scaling corrected for extrapolation
+ inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).to(device=device, dtype=torch.float)
+ inv_freq = (
+ inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
+ + inv_freq_extrapolation * inv_freq_extrapolation_factor
+ )
+ return inv_freq, attention_factor
+
+
+def _compute_longrope_parameters(
+ config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None
+) -> tuple["torch.Tensor", float]:
+ """
+ Computes the inverse frequencies with LongRoPE scaling. Please refer to the
+ [original implementation](https://github.com/microsoft/LongRoPE)
+
+ Args:
+ config ([`~transformers.PretrainedConfig`]):
+ The model configuration. This function assumes that the config will provide at least the following
+ properties:
+
+ * rope_theta (`float`): The base wavelength from which the inverse frequencies will be derived.
+ * hidden_size (`int`): The numerator when deriving a head_dim, if not provided directly.
+ * num_attention_heads (`int`): The denominator when deriving a head_dim, if not provided directly.
+ * max_position_embeddings (`int`): The maximum length of the positional embeddings.
+ * original_max_position_embeddings (`int`, *optional*): The original max position embeddings used during
+ pretraining. If not provided, defaults to `max_position_embeddings`.
+ * rope_scaling (`dict[str, float]`): The standard RoPE scaling parameters, from which the following keys
+ will be accessed:
+ * `attention_factor` (`float`, *optional*): The scaling factor to be applied on the attention
+ computation. If unspecified, it defaults to value recommended by the implementation, inferred from
+ the value of `factor`.
+ * `factor` (`float`, *optional*): The scaling factor to apply to the RoPE embeddings. If both
+ `max_position_embeddings` and `original_max_position_embeddings` are provided, this value will be
+ overridden s the ratio between those values.
+ * `long_factor` (`float`, *optional*): The scale factor applied when computing the inverse
+ frequencies if `seq_len` is provided and greater than `original_max_position_embeddings`.
+ * `short_factor` (`float`, *optional*): The scale factor applied when computing the inverse
+ frequencies if `seq_len` is None or less-than-or-equal-to `original_max_position_embeddings`.
+
+ Additionally, this function will make use of the following properties if they are found in the config:
+
+ * head_dim (`int`, *optional*): The size of the key-value heads in the model. If None, this value will be
+ derived as hidden_size // num_attention_heads.
+ * partial_rotary_factor (`float`, *optional*, defaults to 1.0): If less than 1.0, inverse frequencies
+ will be returned for the first fraction of the head_dim.
+ device (`torch.device`):
+ The device to use for initialization of the inverse frequencies.
+ seq_len (`int`, *optional*):
+ The current sequence length.
+
+ Returns:
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+ post-processing scaling factor applied to the computed cos/sin.
+ """
+ # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
+ base = config.rope_theta
+ partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
+ head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ dim = int(head_dim * partial_rotary_factor)
+ long_factor = config.rope_scaling["long_factor"]
+ short_factor = config.rope_scaling["short_factor"]
+ factor = config.rope_scaling.get("factor")
+ attention_factor = config.rope_scaling.get("attention_factor")
+
+ # NOTE: Phi3 (and potentially other models) modify `max_position_embeddings` and have a
+ # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
+ # values to compute the default attention scaling factor, instead of using `factor`.
+ if original_max_position_embeddings := getattr(config, "original_max_position_embeddings", None):
+ factor = config.max_position_embeddings / original_max_position_embeddings
+ else:
+ original_max_position_embeddings = config.max_position_embeddings
+
+ # Sets the attention factor as suggested in the paper
+ if attention_factor is None:
+ if factor <= 1.0:
+ attention_factor = 1.0
+ else:
+ attention_factor = math.sqrt(1 + math.log(factor) / math.log(original_max_position_embeddings))
+
+ # Compute the inverse frequencies -- scaled based on the target sequence length
+ if seq_len and seq_len > original_max_position_embeddings:
+ ext_factors = torch.tensor(long_factor, dtype=torch.float32, device=device)
+ else:
+ ext_factors = torch.tensor(short_factor, dtype=torch.float32, device=device)
+ inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64, device=device).float() / dim
+ inv_freq = 1.0 / (ext_factors * base**inv_freq_shape)
+
+ return inv_freq, attention_factor
+
+
+def _compute_llama3_parameters(
+ config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None
+) -> tuple["torch.Tensor", float]:
+ """
+ Computes the inverse frequencies for llama 3.1.
+
+ Args:
+ config ([`~transformers.PretrainedConfig`]):
+ The model configuration. This function assumes that the config will provide at least the following
+ properties:
+
+ * rope_theta (`float`): The base wavelength from which the inverse frequencies will be derived.
+ * hidden_size (`int`): The numerator when deriving a head_dim, if not provided directly.
+ * num_attention_heads (`int`): The denominator when deriving a head_dim, if not provided directly.
+ * rope_scaling (`dict[str, float | int]`): The standard RoPE scaling parameters, from which the following
+ keys will be accessed:
+ * `factor` (`float`, *optional*): The scaling factor applied to the inverse frequencies when 1) the
+ wavelength is greater than `low_freq_wavelen` prior to smoothing, and 2) to all inverse frequencies
+ during smoothing.
+ * `high_freq_factor` (`float`): The scale factor used to compute `high_freq_wavelen` and
+ the value for the denominator of the smoothing factor prior to the `low_freq_factor` shift.
+ * `low_freq_factor` (`float`): The scale factor used to compute `low_freq_wavelen` and
+ the shift applied to the numerator and denominator of the smoothing factor.
+ frequencies if `seq_len` is None or less-than-or-equal-to `original_max_position_embeddings`.
+ * `original_max_position_embeddings` (`int`): The original max position embeddings used
+ during pretraining. If not provided, the function falls back to `max_position_embeddings`.
+
+ Additionally, this function will make use of the following properties if they are found in the config:
+
+ * head_dim (`int`, *optional*): The size of the key-value heads in the model. If None, this value will be
+ derived as hidden_size // num_attention_heads.
+ * partial_rotary_factor (`float`, *optional*): If less than 1.0, inverse frequencies will be returned for
+ the first fraction of the head_dim. Defaults to 1.0.
+ device (`torch.device`):
+ The device to use for initialization of the inverse frequencies.
+ seq_len (`int`, *optional*):
+ The current sequence length. Unused for this type of RoPE.
+ Returns:
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+ post-processing scaling factor applied to the computed cos/sin.
+ """
+ # Gets the default RoPE parameters
+ inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len)
+
+ factor = config.rope_scaling["factor"] # `8` in the original implementation
+ low_freq_factor = config.rope_scaling["low_freq_factor"] # `1` in the original implementation
+ high_freq_factor = config.rope_scaling["high_freq_factor"] # `4` in the original implementation
+ old_context_len = config.rope_scaling["original_max_position_embeddings"] # `8192` in the original implementation
+
+ low_freq_wavelen = old_context_len / low_freq_factor
+ high_freq_wavelen = old_context_len / high_freq_factor
+
+ wavelen = 2 * math.pi / inv_freq
+ # wavelen < high_freq_wavelen: do nothing
+ # wavelen > low_freq_wavelen: divide by factor
+ inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
+ # otherwise: interpolate between the two, using a smooth factor
+ smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
+ smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
+ is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
+ inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
+
+ return inv_freq_llama, attention_factor
+
+
+# This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
+# from the model config. You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE
+# parameterizations, as long as the callable has the same signature.
+ROPE_INIT_FUNCTIONS = {
+ "default": _compute_default_rope_parameters,
+ "linear": _compute_linear_scaling_rope_parameters,
+ "dynamic": _compute_dynamic_ntk_parameters,
+ "yarn": _compute_yarn_parameters,
+ "longrope": _compute_longrope_parameters,
+ "llama3": _compute_llama3_parameters,
+}
+
+
+def _check_received_keys(
+ rope_type: str,
+ received_keys: set,
+ required_keys: set,
+ optional_keys: Optional[set] = None,
+ ignore_keys: Optional[set] = None,
+):
+ """Compare the received keys in `config.rope_scaling` against the expected and optional keys"""
+ # BC: "rope_type" was originally "type" -- let's check for "rope_type" when "type" is present
+ if "type" in received_keys:
+ received_keys -= {"type"}
+ required_keys.add("rope_type")
+
+ # Some models need to store model-specific keys, and we don't want to throw warning at them
+ if ignore_keys is not None:
+ received_keys -= ignore_keys
+
+ missing_keys = required_keys - received_keys
+ if missing_keys:
+ raise KeyError(f"Missing required keys in `rope_scaling` for 'rope_type'='{rope_type}': {missing_keys}")
+
+ if optional_keys is not None:
+ unused_keys = received_keys - required_keys - optional_keys
+ else:
+ unused_keys = received_keys - required_keys
+ if unused_keys:
+ logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}")
+
+
+def _validate_default_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
+ rope_scaling = config.rope_scaling
+ rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
+ required_keys = {"rope_type"}
+ received_keys = set(rope_scaling.keys())
+ _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
+
+
+def _validate_linear_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
+ rope_scaling = config.rope_scaling
+ rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
+ required_keys = {"rope_type", "factor"}
+ received_keys = set(rope_scaling.keys())
+ _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
+
+ factor = rope_scaling["factor"]
+ if factor is None or not isinstance(factor, float) or factor < 1.0:
+ logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
+
+
+def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
+ rope_scaling = config.rope_scaling
+ rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
+ required_keys = {"rope_type", "factor"}
+ # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
+ optional_keys = {"original_max_position_embeddings"}
+ received_keys = set(rope_scaling.keys())
+ _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
+
+ factor = rope_scaling["factor"]
+ if factor is None or not isinstance(factor, float) or factor < 1.0:
+ logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
+
+
+def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
+ rope_scaling = config.rope_scaling
+ rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
+ required_keys = {"rope_type", "factor"}
+ optional_keys = {
+ "attention_factor",
+ "beta_fast",
+ "beta_slow",
+ "original_max_position_embeddings",
+ "mscale",
+ "mscale_all_dim",
+ "truncate",
+ }
+ received_keys = set(rope_scaling.keys())
+ _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
+
+ factor = rope_scaling["factor"]
+ if factor is None or not isinstance(factor, float) or factor < 1.0:
+ logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
+
+ attention_factor = rope_scaling.get("attention_factor")
+ if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0):
+ logger.warning(
+ f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
+ )
+ beta_fast = rope_scaling.get("beta_fast")
+ if beta_fast is not None and not isinstance(beta_fast, float):
+ logger.warning(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}")
+ beta_slow = rope_scaling.get("beta_slow")
+ if beta_slow is not None and not isinstance(beta_slow, float):
+ logger.warning(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}")
+
+ if (beta_fast or 32) < (beta_slow or 1):
+ logger.warning(
+ f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} "
+ f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)"
+ )
+
+ # Models should set `config.rope_scaling["original_max_position_embeddings"]` to their original (pre-yarn) context
+ # length, with `config.max_position_embeddings` corresponding to their post-yarn context length.
+ # However, for BC purposes, we allow the former to be unset.
+ original_max_position_embeddings = config.rope_scaling.get("original_max_position_embeddings")
+ if original_max_position_embeddings is not None:
+ # Double-check: `factor` should be the ratio between the pre-yarn and post-yarn context lengths.
+ implicit_factor = config.max_position_embeddings / original_max_position_embeddings
+ if implicit_factor != factor:
+ logger.warning_once(
+ f"The explicitly set RoPE scaling factor (config.rope_scaling['factor'] = {factor}) does not match "
+ "the ratio implicitly set by other parameters (implicit factor = "
+ "post-yarn context length / pre-yarn context length = "
+ "config.max_position_embeddings / config.rope_scaling['original_max_position_embeddings'] = "
+ f"{implicit_factor}). Using the explicit factor ({factor}) in YaRN. This may cause unexpected "
+ "behaviour in model usage, please correct the 'max_position_embeddings' fields in the model config."
+ )
+ # No `config.rope_scaling["original_max_position_embeddings"]`. Is `config.max_position_embeddings` the
+ # pre-yarn or the post-yarn context length?
+ # BC: we assume it is the pre-yarn context length.
+ else:
+ logger.warning_once(
+ "config.rope_scaling['original_max_position_embeddings'], the pre-yarn context length, is unset. We will "
+ "**assume** config.max_position_embeddings holds the pre-yarn context length. Some use cases may expect "
+ "config.max_position_embeddings to hold the post-yarn context length (pre-yarn context length * "
+ "factor) -- we recommend updating both fields for optimal downstream model usage."
+ )
+
+
+def _validate_longrope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
+ rope_scaling = config.rope_scaling
+ rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
+ required_keys = {"rope_type", "short_factor", "long_factor"}
+ # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
+ optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"}
+ received_keys = set(rope_scaling.keys())
+ _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
+
+ partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
+ head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ dim = int(head_dim * partial_rotary_factor)
+
+ short_factor = rope_scaling.get("short_factor")
+ if not isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor):
+ logger.warning(f"`rope_scaling`'s short_factor field must be a list of numbers, got {short_factor}")
+ if len(short_factor) != dim // 2:
+ logger.warning(f"`rope_scaling`'s short_factor field must have length {dim // 2}, got {len(short_factor)}")
+
+ long_factor = rope_scaling.get("long_factor")
+ if not isinstance(long_factor, list) and all(isinstance(x, (int, float)) for x in long_factor):
+ logger.warning(f"`rope_scaling`'s long_factor field must be a list of numbers, got {long_factor}")
+ if len(long_factor) != dim // 2:
+ logger.warning(f"`rope_scaling`'s long_factor field must have length {dim // 2}, got {len(long_factor)}")
+
+ # Handle Phi3 divergence: prefer the use of `attention_factor` and/or `factor` over
+ # `original_max_position_embeddings` to compute internal variables. The latter lives outside `rope_scaling` and is
+ # unique to longrope (= undesirable)
+ if hasattr(config, "original_max_position_embeddings"):
+ logger.warning_once(
+ "This model has set a `original_max_position_embeddings` field, to be used together with "
+ "`max_position_embeddings` to determine a scaling factor. Please set the `factor` field of `rope_scaling`"
+ "with this ratio instead -- we recommend the use of this field over `original_max_position_embeddings`, "
+ "as it is compatible with most model architectures."
+ )
+ else:
+ factor = rope_scaling.get("factor")
+ if factor is None:
+ logger.warning("Missing required keys in `rope_scaling`: 'factor'")
+ elif not isinstance(factor, float) or factor < 1.0:
+ logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
+
+ attention_factor = rope_scaling.get("attention_factor")
+ if attention_factor is not None:
+ if not isinstance(attention_factor, float) or attention_factor < 0.0:
+ logger.warning(
+ f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
+ )
+
+
+def _validate_llama3_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
+ rope_scaling = config.rope_scaling
+ rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
+ required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"}
+ received_keys = set(rope_scaling.keys())
+ _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
+
+ factor = rope_scaling["factor"]
+ if factor is None or not isinstance(factor, float) or factor < 1.0:
+ logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
+
+ low_freq_factor = rope_scaling["low_freq_factor"]
+ high_freq_factor = rope_scaling["high_freq_factor"]
+ if low_freq_factor is None or not isinstance(low_freq_factor, float):
+ logger.warning(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}")
+ if high_freq_factor is None or not isinstance(high_freq_factor, float):
+ logger.warning(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}")
+ if high_freq_factor <= low_freq_factor:
+ logger.warning(
+ "`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor="
+ f"{high_freq_factor} and low_freq_factor={low_freq_factor}"
+ )
+
+ original_max_position_embeddings = rope_scaling["original_max_position_embeddings"]
+ if original_max_position_embeddings is None or not isinstance(original_max_position_embeddings, int):
+ logger.warning(
+ "`rope_scaling`'s original_max_position_embeddings field must be an integer, got "
+ f"{original_max_position_embeddings}"
+ )
+ if original_max_position_embeddings >= config.max_position_embeddings:
+ logger.warning(
+ "`rope_scaling`'s original_max_position_embeddings field must be less than max_position_embeddings, got "
+ f"{original_max_position_embeddings} and max_position_embeddings={config.max_position_embeddings}"
+ )
+
+
+# Like `ROPE_INIT_FUNCTIONS`, this validation function mapping can be dynamically updated for custom RoPE types.
+ROPE_VALIDATION_FUNCTIONS = {
+ "default": _validate_default_rope_parameters,
+ "linear": _validate_linear_scaling_rope_parameters,
+ "dynamic": _validate_dynamic_scaling_rope_parameters,
+ "yarn": _validate_yarn_parameters,
+ "longrope": _validate_longrope_parameters,
+ "llama3": _validate_llama3_parameters,
+}
+
+
+def rope_config_validation(config: PretrainedConfig, ignore_keys: Optional[set] = None):
+ """
+ Validate the RoPE config arguments, given a `PretrainedConfig` object
+ """
+ rope_scaling = getattr(config, "rope_scaling", None) # not a default parameter in `PretrainedConfig`
+ if rope_scaling is None:
+ return
+
+ # BC: "rope_type" was originally "type"
+ rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
+ validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type)
+ if validation_fn is not None:
+ validation_fn(config, ignore_keys=ignore_keys)
+ else:
+ logger.warning(
+ f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'"
+ )
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_tf_outputs.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_tf_outputs.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7491b67f9aebb93b95632a5a7db2fd85cf5b4c7
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_tf_outputs.py
@@ -0,0 +1,990 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import warnings
+from dataclasses import dataclass
+
+import tensorflow as tf
+
+from .utils import ModelOutput
+
+
+@dataclass
+class TFBaseModelOutput(ModelOutput):
+ """
+ Base class for model's outputs, with potential hidden states and attentions.
+
+ Args:
+ last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ hidden_states (`tuple(tf.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ last_hidden_state: tf.Tensor | None = None
+ hidden_states: tuple[tf.Tensor] | None = None
+ attentions: tuple[tf.Tensor] | None = None
+
+
+@dataclass
+class TFBaseModelOutputWithNoAttention(ModelOutput):
+ """
+ Base class for model's outputs, with potential hidden states.
+
+ Args:
+ last_hidden_state (`tf.Tensor` shape `(batch_size, num_channels, height, width)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
+ the output of each layer) of shape `(batch_size, num_channels, height, width)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ """
+
+ last_hidden_state: tf.Tensor | None = None
+ hidden_states: tuple[tf.Tensor, ...] | None = None
+
+
+@dataclass
+class TFBaseModelOutputWithPooling(ModelOutput):
+ """
+ Base class for model's outputs that also contains a pooling of the last hidden states.
+
+ Args:
+ last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ pooler_output (`tf.Tensor` of shape `(batch_size, hidden_size)`):
+ Last layer hidden-state of the first token of the sequence (classification token) further processed by a
+ Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence
+ prediction (classification) objective during pretraining.
+
+ This output is usually *not* a good summary of the semantic content of the input, you're often better with
+ averaging or pooling the sequence of hidden-states for the whole input sequence.
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ last_hidden_state: tf.Tensor | None = None
+ pooler_output: tf.Tensor | None = None
+ hidden_states: tuple[tf.Tensor] | None = None
+ attentions: tuple[tf.Tensor] | None = None
+
+
+@dataclass
+class TFBaseModelOutputWithPoolingAndNoAttention(ModelOutput):
+ """
+ Base class for model's outputs that also contains a pooling of the last hidden states.
+
+ Args:
+ last_hidden_state (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ pooler_output (`tf.Tensor` of shape `(batch_size, hidden_size)`):
+ Last layer hidden-state after a pooling operation on the spatial dimensions.
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
+ the output of each layer) of shape `(batch_size, num_channels, height, width)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ """
+
+ last_hidden_state: tf.Tensor | None = None
+ pooler_output: tf.Tensor | None = None
+ hidden_states: tuple[tf.Tensor, ...] | None = None
+
+
+@dataclass
+class TFBaseModelOutputWithPoolingAndCrossAttentions(ModelOutput):
+ """
+ Base class for model's outputs that also contains a pooling of the last hidden states.
+
+ Args:
+ last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ pooler_output (`tf.Tensor` of shape `(batch_size, hidden_size)`):
+ Last layer hidden-state of the first token of the sequence (classification token) further processed by a
+ Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence
+ prediction (classification) objective during pretraining.
+
+ This output is usually *not* a good summary of the semantic content of the input, you're often better with
+ averaging or pooling the sequence of hidden-states for the whole input sequence.
+ past_key_values (`list[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads,
+ sequence_length, embed_size_per_head)`).
+
+ Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+ weighted average in the cross-attention heads.
+ """
+
+ last_hidden_state: tf.Tensor | None = None
+ pooler_output: tf.Tensor | None = None
+ past_key_values: list[tf.Tensor] | None = None
+ hidden_states: tuple[tf.Tensor] | None = None
+ attentions: tuple[tf.Tensor] | None = None
+ cross_attentions: tuple[tf.Tensor] | None = None
+
+
+@dataclass
+class TFBaseModelOutputWithPast(ModelOutput):
+ """
+ Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
+
+ Args:
+ last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+
+ If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
+ hidden_size)` is output.
+ past_key_values (`list[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads,
+ sequence_length, embed_size_per_head)`).
+
+ Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ last_hidden_state: tf.Tensor | None = None
+ past_key_values: list[tf.Tensor] | None = None
+ hidden_states: tuple[tf.Tensor] | None = None
+ attentions: tuple[tf.Tensor] | None = None
+
+
+@dataclass
+class TFBaseModelOutputWithCrossAttentions(ModelOutput):
+ """
+ Base class for model's outputs, with potential hidden states and attentions.
+
+ Args:
+ last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ hidden_states (`tuple(tf.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+ weighted average in the cross-attention heads.
+ """
+
+ last_hidden_state: tf.Tensor | None = None
+ hidden_states: tuple[tf.Tensor] | None = None
+ attentions: tuple[tf.Tensor] | None = None
+ cross_attentions: tuple[tf.Tensor] | None = None
+
+
+@dataclass
+class TFBaseModelOutputWithPastAndCrossAttentions(ModelOutput):
+ """
+ Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
+
+ Args:
+ last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+
+ If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
+ hidden_size)` is output.
+ past_key_values (`list[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads,
+ sequence_length, embed_size_per_head)`).
+
+ Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ hidden_states (`tuple(tf.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+ weighted average in the cross-attention heads.
+ """
+
+ last_hidden_state: tf.Tensor | None = None
+ past_key_values: list[tf.Tensor] | None = None
+ hidden_states: tuple[tf.Tensor] | None = None
+ attentions: tuple[tf.Tensor] | None = None
+ cross_attentions: tuple[tf.Tensor] | None = None
+
+
+@dataclass
+class TFSeq2SeqModelOutput(ModelOutput):
+ """
+ Base class for model encoder's outputs that also contains : pre-computed hidden states that can speed up sequential
+ decoding.
+
+ Args:
+ last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the decoder of the model.
+
+ If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
+ hidden_size)` is output.
+ past_key_values (`list[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads,
+ sequence_length, embed_size_per_head)`).
+
+ Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
+ used (see `past_key_values` input) to speed up sequential decoding.
+ decoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
+ decoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+ weighted average in the cross-attention heads.
+ encoder_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder of the model.
+ encoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
+ encoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ """
+
+ last_hidden_state: tf.Tensor | None = None
+ past_key_values: list[tf.Tensor] | None = None
+ decoder_hidden_states: tuple[tf.Tensor] | None = None
+ decoder_attentions: tuple[tf.Tensor] | None = None
+ cross_attentions: tuple[tf.Tensor] | None = None
+ encoder_last_hidden_state: tf.Tensor | None = None
+ encoder_hidden_states: tuple[tf.Tensor] | None = None
+ encoder_attentions: tuple[tf.Tensor] | None = None
+
+
+@dataclass
+class TFCausalLMOutput(ModelOutput):
+ """
+ Base class for causal language model (or autoregressive) outputs.
+
+ Args:
+ loss (`tf.Tensor` of shape `(n,)`, *optional*, where n is the number of non-masked labels, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: tf.Tensor | None = None
+ logits: tf.Tensor | None = None
+ hidden_states: tuple[tf.Tensor] | None = None
+ attentions: tuple[tf.Tensor] | None = None
+
+
+@dataclass
+class TFCausalLMOutputWithPast(ModelOutput):
+ """
+ Base class for causal language model (or autoregressive) outputs.
+
+ Args:
+ loss (`tf.Tensor` of shape `(n,)`, *optional*, where n is the number of non-masked labels, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`list[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads,
+ sequence_length, embed_size_per_head)`).
+
+ Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: tf.Tensor | None = None
+ logits: tf.Tensor | None = None
+ past_key_values: list[tf.Tensor] | None = None
+ hidden_states: tuple[tf.Tensor] | None = None
+ attentions: tuple[tf.Tensor] | None = None
+
+
+@dataclass
+class TFCausalLMOutputWithCrossAttentions(ModelOutput):
+ """
+ Base class for causal language model (or autoregressive) outputs.
+
+ Args:
+ loss (`tf.Tensor` of shape `(n,)`, *optional*, where n is the number of non-masked labels, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+ weighted average in the cross-attention heads.
+ past_key_values (`list[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads,
+ sequence_length, embed_size_per_head)`).
+
+ Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ """
+
+ loss: tf.Tensor | None = None
+ logits: tf.Tensor | None = None
+ past_key_values: list[tf.Tensor] | None = None
+ hidden_states: tuple[tf.Tensor] | None = None
+ attentions: tuple[tf.Tensor] | None = None
+ cross_attentions: tuple[tf.Tensor] | None = None
+
+
+@dataclass
+class TFMaskedLMOutput(ModelOutput):
+ """
+ Base class for masked language models outputs.
+
+ Args:
+ loss (`tf.Tensor` of shape `(n,)`, *optional*, where n is the number of non-masked labels, returned when `labels` is provided):
+ Masked language modeling (MLM) loss.
+ logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: tf.Tensor | None = None
+ logits: tf.Tensor | None = None
+ hidden_states: tuple[tf.Tensor] | None = None
+ attentions: tuple[tf.Tensor] | None = None
+
+
+@dataclass
+class TFSeq2SeqLMOutput(ModelOutput):
+ """
+ Base class for sequence-to-sequence language models outputs.
+
+ Args:
+ loss (`tf.Tensor` of shape `(n,)`, *optional*, where n is the number of non-masked labels, returned when `labels` is provided):
+ Language modeling loss.
+ logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`list[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads,
+ sequence_length, embed_size_per_head)`).
+
+ Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
+ used (see `past_key_values` input) to speed up sequential decoding.
+ decoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
+ decoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+ weighted average in the cross-attention heads.
+ encoder_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder of the model.
+ encoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
+ encoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ """
+
+ loss: tf.Tensor | None = None
+ logits: tf.Tensor | None = None
+ past_key_values: list[tf.Tensor] | None = None
+ decoder_hidden_states: tuple[tf.Tensor] | None = None
+ decoder_attentions: tuple[tf.Tensor] | None = None
+ cross_attentions: tuple[tf.Tensor] | None = None
+ encoder_last_hidden_state: tf.Tensor | None = None
+ encoder_hidden_states: tuple[tf.Tensor] | None = None
+ encoder_attentions: tuple[tf.Tensor] | None = None
+
+
+@dataclass
+class TFNextSentencePredictorOutput(ModelOutput):
+ """
+ Base class for outputs of models predicting if two sentences are consecutive or not.
+
+ Args:
+ loss (`tf.Tensor` of shape `(n,)`, *optional*, where n is the number of non-masked labels, returned when `next_sentence_label` is provided):
+ Next sentence prediction loss.
+ logits (`tf.Tensor` of shape `(batch_size, 2)`):
+ Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
+ before SoftMax).
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: tf.Tensor | None = None
+ logits: tf.Tensor | None = None
+ hidden_states: tuple[tf.Tensor] | None = None
+ attentions: tuple[tf.Tensor] | None = None
+
+
+@dataclass
+class TFSequenceClassifierOutput(ModelOutput):
+ """
+ Base class for outputs of sentence classification models.
+
+ Args:
+ loss (`tf.Tensor` of shape `(batch_size, )`, *optional*, returned when `labels` is provided):
+ Classification (or regression if config.num_labels==1) loss.
+ logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`):
+ Classification (or regression if config.num_labels==1) scores (before SoftMax).
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: tf.Tensor | None = None
+ logits: tf.Tensor | None = None
+ hidden_states: tuple[tf.Tensor] | None = None
+ attentions: tuple[tf.Tensor] | None = None
+
+
+@dataclass
+class TFSeq2SeqSequenceClassifierOutput(ModelOutput):
+ """
+ Base class for outputs of sequence-to-sequence sentence classification models.
+
+ Args:
+ loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `label` is provided):
+ Classification (or regression if config.num_labels==1) loss.
+ logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`):
+ Classification (or regression if config.num_labels==1) scores (before SoftMax).
+ past_key_values (`list[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads,
+ sequence_length, embed_size_per_head)`).
+
+ Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
+ used (see `past_key_values` input) to speed up sequential decoding.
+ decoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
+ decoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`
+ encoder_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder of the model.
+ encoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
+ encoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ """
+
+ loss: tf.Tensor | None = None
+ logits: tf.Tensor | None = None
+ past_key_values: list[tf.Tensor] | None = None
+ decoder_hidden_states: tuple[tf.Tensor] | None = None
+ decoder_attentions: tuple[tf.Tensor] | None = None
+ cross_attentions: tuple[tf.Tensor] | None = None
+ encoder_last_hidden_state: tf.Tensor | None = None
+ encoder_hidden_states: tuple[tf.Tensor] | None = None
+ encoder_attentions: tuple[tf.Tensor] | None = None
+
+
+@dataclass
+class TFSemanticSegmenterOutput(ModelOutput):
+ """
+ Base class for outputs of semantic segmentation models.
+
+ Args:
+ loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Classification (or regression if config.num_labels==1) loss.
+ logits (`tf.Tensor` of shape `(batch_size, config.num_labels, logits_height, logits_width)`):
+ Classification scores for each pixel.
+
+
+
+ The logits returned do not necessarily have the same size as the `pixel_values` passed as inputs. This is
+ to avoid doing two interpolations and lose some quality when a user needs to resize the logits to the
+ original image size as post-processing. You should always check your logits shape and resize as needed.
+
+
+
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
+ the output of each layer) of shape `(batch_size, patch_size, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: tf.Tensor | None = None
+ logits: tf.Tensor | None = None
+ hidden_states: tuple[tf.Tensor] | None = None
+ attentions: tuple[tf.Tensor] | None = None
+
+
+@dataclass
+class TFSemanticSegmenterOutputWithNoAttention(ModelOutput):
+ """
+ Base class for outputs of semantic segmentation models that do not output attention scores.
+
+ Args:
+ loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Classification (or regression if config.num_labels==1) loss.
+ logits (`tf.Tensor` of shape `(batch_size, config.num_labels, logits_height, logits_width)`):
+ Classification scores for each pixel.
+
+
+
+ The logits returned do not necessarily have the same size as the `pixel_values` passed as inputs. This is
+ to avoid doing two interpolations and lose some quality when a user needs to resize the logits to the
+ original image size as post-processing. You should always check your logits shape and resize as needed.
+
+
+
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
+ the output of each layer) of shape `(batch_size, patch_size, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ """
+
+ loss: tf.Tensor | None = None
+ logits: tf.Tensor | None = None
+ hidden_states: tuple[tf.Tensor] | None = None
+
+
+@dataclass
+class TFImageClassifierOutput(ModelOutput):
+ """
+ Base class for outputs of image classification models.
+
+ Args:
+ loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Classification (or regression if config.num_labels==1) loss.
+ logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`):
+ Classification (or regression if config.num_labels==1) scores (before SoftMax).
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
+ the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states (also called
+ feature maps) of the model at the output of each stage.
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: tf.Tensor | None = None
+ logits: tf.Tensor | None = None
+ hidden_states: tuple[tf.Tensor] | None = None
+ attentions: tuple[tf.Tensor] | None = None
+
+
+@dataclass
+class TFMultipleChoiceModelOutput(ModelOutput):
+ """
+ Base class for outputs of multiple choice models.
+
+ Args:
+ loss (`tf.Tensor` of shape *(batch_size, )*, *optional*, returned when `labels` is provided):
+ Classification loss.
+ logits (`tf.Tensor` of shape `(batch_size, num_choices)`):
+ *num_choices* is the second dimension of the input tensors. (see *input_ids* above).
+
+ Classification scores (before SoftMax).
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: tf.Tensor | None = None
+ logits: tf.Tensor | None = None
+ hidden_states: tuple[tf.Tensor] | None = None
+ attentions: tuple[tf.Tensor] | None = None
+
+
+@dataclass
+class TFTokenClassifierOutput(ModelOutput):
+ """
+ Base class for outputs of token classification models.
+
+ Args:
+ loss (`tf.Tensor` of shape `(n,)`, *optional*, where n is the number of unmasked labels, returned when `labels` is provided) :
+ Classification loss.
+ logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.num_labels)`):
+ Classification scores (before SoftMax).
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: tf.Tensor | None = None
+ logits: tf.Tensor | None = None
+ hidden_states: tuple[tf.Tensor] | None = None
+ attentions: tuple[tf.Tensor] | None = None
+
+
+@dataclass
+class TFQuestionAnsweringModelOutput(ModelOutput):
+ """
+ Base class for outputs of question answering models.
+
+ Args:
+ loss (`tf.Tensor` of shape `(batch_size, )`, *optional*, returned when `start_positions` and `end_positions` are provided):
+ Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
+ start_logits (`tf.Tensor` of shape `(batch_size, sequence_length)`):
+ Span-start scores (before SoftMax).
+ end_logits (`tf.Tensor` of shape `(batch_size, sequence_length)`):
+ Span-end scores (before SoftMax).
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: tf.Tensor | None = None
+ start_logits: tf.Tensor | None = None
+ end_logits: tf.Tensor | None = None
+ hidden_states: tuple[tf.Tensor] | None = None
+ attentions: tuple[tf.Tensor] | None = None
+
+
+@dataclass
+class TFSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
+ """
+ Base class for outputs of sequence-to-sequence question answering models.
+
+ Args:
+ loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
+ start_logits (`tf.Tensor` of shape `(batch_size, sequence_length)`):
+ Span-start scores (before SoftMax).
+ end_logits (`tf.Tensor` of shape `(batch_size, sequence_length)`):
+ Span-end scores (before SoftMax).
+ past_key_values (`list[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads,
+ sequence_length, embed_size_per_head)`).
+
+ Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
+ used (see `past_key_values` input) to speed up sequential decoding.
+ decoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
+ decoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ encoder_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder of the model.
+ encoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
+ encoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ """
+
+ loss: tf.Tensor | None = None
+ start_logits: tf.Tensor | None = None
+ end_logits: tf.Tensor | None = None
+ past_key_values: list[tf.Tensor] | None = None
+ decoder_hidden_states: tuple[tf.Tensor] | None = None
+ decoder_attentions: tuple[tf.Tensor] | None = None
+ encoder_last_hidden_state: tf.Tensor | None = None
+ encoder_hidden_states: tuple[tf.Tensor] | None = None
+ encoder_attentions: tuple[tf.Tensor] | None = None
+
+
+@dataclass
+class TFSequenceClassifierOutputWithPast(ModelOutput):
+ """
+ Base class for outputs of sentence classification models.
+
+ Args:
+ loss (`tf.Tensor` of shape `(batch_size, )`, *optional*, returned when `labels` is provided):
+ Classification (or regression if config.num_labels==1) loss.
+ logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`):
+ Classification (or regression if config.num_labels==1) scores (before SoftMax).
+ past_key_values (`list[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads,
+ sequence_length, embed_size_per_head)`).
+
+ Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: tf.Tensor | None = None
+ logits: tf.Tensor | None = None
+ past_key_values: list[tf.Tensor] | None = None
+ hidden_states: tuple[tf.Tensor] | None = None
+ attentions: tuple[tf.Tensor] | None = None
+
+
+@dataclass
+class TFImageClassifierOutputWithNoAttention(ModelOutput):
+ """
+ Base class for outputs of image classification models.
+
+ Args:
+ loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Classification (or regression if config.num_labels==1) loss.
+ logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`):
+ Classification (or regression if config.num_labels==1) scores (before SoftMax).
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
+ the output of each stage) of shape `(batch_size, num_channels, height, width)`. Hidden-states (also called
+ feature maps) of the model at the output of each stage.
+ """
+
+ loss: tf.Tensor | None = None
+ logits: tf.Tensor | None = None
+ hidden_states: tuple[tf.Tensor, ...] | None = None
+
+
+@dataclass
+class TFMaskedImageModelingOutput(ModelOutput):
+ """
+ Base class for outputs of masked image completion / in-painting models.
+
+ Args:
+ loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
+ Reconstruction loss.
+ reconstruction (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
+ Reconstructed / completed images.
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when
+ `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
+ the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states (also called
+ feature maps) of the model at the output of each stage.
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when
+ `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, sequence_length)`.
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: tf.Tensor | None = None
+ reconstruction: tf.Tensor | None = None
+ hidden_states: tuple[tf.Tensor] | None = None
+ attentions: tuple[tf.Tensor] | None = None
+
+ @property
+ def logits(self):
+ warnings.warn(
+ "logits attribute is deprecated and will be removed in version 5 of Transformers."
+ " Please use the reconstruction attribute to retrieve the final output instead.",
+ FutureWarning,
+ )
+ return self.reconstruction
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_tf_pytorch_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_tf_pytorch_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f688af7be36439311465d019de36c20c6aaae77
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_tf_pytorch_utils.py
@@ -0,0 +1,676 @@
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch - TF 2.0 general utilities."""
+
+import os
+import re
+
+import numpy
+
+from .utils import (
+ ExplicitEnum,
+ check_torch_load_is_safe,
+ expand_dims,
+ is_numpy_array,
+ is_safetensors_available,
+ is_torch_tensor,
+ logging,
+ reshape,
+ squeeze,
+ tensor_size,
+)
+from .utils import transpose as transpose_func
+
+
+if is_safetensors_available():
+ from safetensors import safe_open
+
+
+logger = logging.get_logger(__name__)
+
+
+class TransposeType(ExplicitEnum):
+ """
+ Possible ...
+ """
+
+ NO = "no"
+ SIMPLE = "simple"
+ CONV1D = "conv1d"
+ CONV2D = "conv2d"
+
+
+def convert_tf_weight_name_to_pt_weight_name(
+ tf_name, start_prefix_to_remove="", tf_weight_shape=None, name_scope=None
+):
+ """
+ Convert a TF 2.0 model variable name in a pytorch model weight name.
+
+ Conventions for TF2.0 scopes -> PyTorch attribute names conversions:
+
+ - '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch)
+ - '_._' is replaced by a new level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList)
+
+ return tuple with:
+
+ - pytorch model weight name
+ - transpose: `TransposeType` member indicating whether and how TF2.0 and PyTorch weights matrices should be
+ transposed with regards to each other
+ """
+ if name_scope is not None:
+ if not tf_name.startswith(name_scope) and "final_logits_bias" not in tf_name:
+ raise ValueError(
+ f"Weight name {tf_name} does not start with name_scope {name_scope}. This is an internal error "
+ "in Transformers, so (unless you were doing something really evil) please open an issue to report it!"
+ )
+ tf_name = tf_name[len(name_scope) :]
+ tf_name = tf_name.lstrip("/")
+ tf_name = tf_name.replace(":0", "") # device ids
+ if (len(tf_name) > 2048 and "___" in tf_name) or tf_name.count("___") > 10:
+ # ReDOS check
+ raise ValueError("TF variable name is too long or contains too many ___ separators: " + tf_name)
+ tf_name = re.sub(
+ r"/[^/]*___([^/]*)/", r"/\1/", tf_name
+ ) # '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch)
+ tf_name = tf_name.replace(
+ "_._", "/"
+ ) # '_._' is replaced by a level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList)
+ tf_name = re.sub(r"//+", "/", tf_name) # Remove empty levels at the end
+ tf_name = tf_name.split("/") # Convert from TF2.0 '/' separators to PyTorch '.' separators
+ # Some weights have a single name without "/" such as final_logits_bias in BART
+ if len(tf_name) > 1:
+ tf_name = tf_name[1:] # Remove level zero
+
+ tf_weight_shape = list(tf_weight_shape)
+
+ # When should we transpose the weights
+ if tf_name[-1] == "kernel" and tf_weight_shape is not None and len(tf_weight_shape) == 4:
+ transpose = TransposeType.CONV2D
+ elif tf_name[-1] == "kernel" and tf_weight_shape is not None and len(tf_weight_shape) == 3:
+ transpose = TransposeType.CONV1D
+ elif bool(
+ tf_name[-1] in ["kernel", "pointwise_kernel", "depthwise_kernel"]
+ or "emb_projs" in tf_name
+ or "out_projs" in tf_name
+ ):
+ transpose = TransposeType.SIMPLE
+ else:
+ transpose = TransposeType.NO
+
+ # Convert standard TF2.0 names in PyTorch names
+ if tf_name[-1] == "kernel" or tf_name[-1] == "embeddings" or tf_name[-1] == "gamma":
+ tf_name[-1] = "weight"
+ if tf_name[-1] == "beta":
+ tf_name[-1] = "bias"
+
+ # The SeparableConv1D TF layer contains two weights that are translated to PyTorch Conv1D here
+ if tf_name[-1] == "pointwise_kernel" or tf_name[-1] == "depthwise_kernel":
+ tf_name[-1] = tf_name[-1].replace("_kernel", ".weight")
+
+ # Remove prefix if needed
+ tf_name = ".".join(tf_name)
+ if start_prefix_to_remove:
+ tf_name = tf_name.replace(start_prefix_to_remove, "", 1)
+
+ return tf_name, transpose
+
+
+def apply_transpose(transpose: TransposeType, weight, match_shape=None, pt_to_tf=True):
+ """
+ Apply a transpose to some weight then tries to reshape the weight to the same shape as a given shape, all in a
+ framework agnostic way.
+ """
+ if transpose is TransposeType.CONV2D:
+ # Conv2D weight:
+ # PT: (num_out_channel, num_in_channel, kernel[0], kernel[1])
+ # -> TF: (kernel[0], kernel[1], num_in_channel, num_out_channel)
+ axes = (2, 3, 1, 0) if pt_to_tf else (3, 2, 0, 1)
+ weight = transpose_func(weight, axes=axes)
+ elif transpose is TransposeType.CONV1D:
+ # Conv1D weight:
+ # PT: (num_out_channel, num_in_channel, kernel)
+ # -> TF: (kernel, num_in_channel, num_out_channel)
+ weight = transpose_func(weight, axes=(2, 1, 0))
+ elif transpose is TransposeType.SIMPLE:
+ weight = transpose_func(weight)
+
+ if match_shape is None:
+ return weight
+
+ if len(match_shape) < len(weight.shape):
+ weight = squeeze(weight)
+ elif len(match_shape) > len(weight.shape):
+ weight = expand_dims(weight, axis=0)
+
+ if list(match_shape) != list(weight.shape):
+ try:
+ weight = reshape(weight, match_shape)
+ except AssertionError as e:
+ e.args += (match_shape, match_shape)
+ raise e
+
+ return weight
+
+
+#####################
+# PyTorch => TF 2.0 #
+#####################
+
+
+def load_pytorch_checkpoint_in_tf2_model(
+ tf_model,
+ pytorch_checkpoint_path,
+ tf_inputs=None,
+ allow_missing_keys=False,
+ output_loading_info=False,
+ _prefix=None,
+ tf_to_pt_weight_rename=None,
+):
+ """Load pytorch checkpoints in a TF 2.0 model"""
+ try:
+ import tensorflow as tf # noqa: F401
+ import torch # noqa: F401
+ from safetensors.torch import load_file as safe_load_file # noqa: F401
+ except ImportError:
+ logger.error(
+ "Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
+ "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
+ )
+ raise
+
+ # Treats a single file as a collection of shards with 1 shard.
+ if isinstance(pytorch_checkpoint_path, str):
+ pytorch_checkpoint_path = [pytorch_checkpoint_path]
+
+ # Loads all shards into a single state dictionary
+ pt_state_dict = {}
+ for path in pytorch_checkpoint_path:
+ pt_path = os.path.abspath(path)
+ logger.info(f"Loading PyTorch weights from {pt_path}")
+ if pt_path.endswith(".safetensors"):
+ state_dict = safe_load_file(pt_path)
+ else:
+ check_torch_load_is_safe()
+ state_dict = torch.load(pt_path, map_location="cpu", weights_only=True)
+
+ pt_state_dict.update(state_dict)
+
+ logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters")
+
+ return load_pytorch_weights_in_tf2_model(
+ tf_model,
+ pt_state_dict,
+ tf_inputs=tf_inputs,
+ allow_missing_keys=allow_missing_keys,
+ output_loading_info=output_loading_info,
+ _prefix=_prefix,
+ tf_to_pt_weight_rename=tf_to_pt_weight_rename,
+ )
+
+
+def load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=None, allow_missing_keys=False):
+ """Load pytorch checkpoints in a TF 2.0 model"""
+ pt_state_dict = pt_model.state_dict()
+
+ return load_pytorch_weights_in_tf2_model(
+ tf_model, pt_state_dict, tf_inputs=tf_inputs, allow_missing_keys=allow_missing_keys
+ )
+
+
+def load_pytorch_weights_in_tf2_model(
+ tf_model,
+ pt_state_dict,
+ tf_inputs=None,
+ allow_missing_keys=False,
+ output_loading_info=False,
+ _prefix=None,
+ tf_to_pt_weight_rename=None,
+):
+ """Load pytorch state_dict in a TF 2.0 model."""
+ try:
+ import tensorflow as tf # noqa: F401
+ import torch # noqa: F401
+ except ImportError:
+ logger.error(
+ "Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
+ "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
+ )
+ raise
+
+ # Numpy doesn't understand bfloat16, so upcast to a dtype that doesn't lose precision
+ pt_state_dict = {
+ k: v.numpy() if v.dtype != torch.bfloat16 else v.float().numpy() for k, v in pt_state_dict.items()
+ }
+ return load_pytorch_state_dict_in_tf2_model(
+ tf_model,
+ pt_state_dict,
+ tf_inputs=tf_inputs,
+ allow_missing_keys=allow_missing_keys,
+ output_loading_info=output_loading_info,
+ _prefix=_prefix,
+ tf_to_pt_weight_rename=tf_to_pt_weight_rename,
+ )
+
+
+def _log_key_warnings(missing_keys, unexpected_keys, mismatched_keys, class_name):
+ if len(unexpected_keys) > 0:
+ logger.warning(
+ "Some weights of the PyTorch model were not used when initializing the TF 2.0 model"
+ f" {class_name}: {unexpected_keys}\n- This IS expected if you are initializing"
+ f" {class_name} from a PyTorch model trained on another task or with another architecture"
+ " (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).\n- This IS"
+ f" NOT expected if you are initializing {class_name} from a PyTorch model that you expect"
+ " to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a"
+ " BertForSequenceClassification model)."
+ )
+ else:
+ logger.warning(f"All PyTorch model weights were used when initializing {class_name}.\n")
+ if len(missing_keys) > 0:
+ logger.warning(
+ f"Some weights or buffers of the TF 2.0 model {class_name} were not initialized from the"
+ f" PyTorch model and are newly initialized: {missing_keys}\nYou should probably TRAIN this model on a"
+ " down-stream task to be able to use it for predictions and inference."
+ )
+ else:
+ logger.warning(
+ f"All the weights of {class_name} were initialized from the PyTorch model.\n"
+ "If your task is similar to the task the model of the checkpoint was trained on, "
+ f"you can already use {class_name} for predictions without further training."
+ )
+
+ if len(mismatched_keys) > 0:
+ mismatched_warning = "\n".join(
+ [
+ f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
+ for key, shape1, shape2 in mismatched_keys
+ ]
+ )
+ logger.warning(
+ f"Some weights of {class_name} were not initialized from the model checkpoint"
+ f" are newly initialized because the shapes did not"
+ f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
+ " to use it for predictions and inference."
+ )
+
+
+def load_pytorch_state_dict_in_tf2_model(
+ tf_model,
+ pt_state_dict,
+ tf_inputs=None,
+ allow_missing_keys=False,
+ output_loading_info=False,
+ _prefix=None,
+ tf_to_pt_weight_rename=None,
+ ignore_mismatched_sizes=False,
+ skip_logger_warnings=False,
+):
+ """Load a pytorch state_dict in a TF 2.0 model. pt_state_dict can be either an actual dict or a lazy-loading
+ safetensors archive created with the safe_open() function."""
+ import tensorflow as tf
+
+ if tf_inputs is None:
+ tf_inputs = tf_model.dummy_inputs
+
+ if _prefix is None:
+ _prefix = ""
+ if tf_inputs:
+ with tf.name_scope(_prefix):
+ tf_model(tf_inputs, training=False) # Make sure model is built
+ # Convert old format to new format if needed from a PyTorch state_dict
+ tf_keys_to_pt_keys = {}
+ for key in pt_state_dict:
+ new_key = None
+ if "gamma" in key:
+ new_key = key.replace("gamma", "weight")
+ if "beta" in key:
+ new_key = key.replace("beta", "bias")
+ if "running_var" in key:
+ new_key = key.replace("running_var", "moving_variance")
+ if "running_mean" in key:
+ new_key = key.replace("running_mean", "moving_mean")
+
+ # New `weight_norm` from https://github.com/huggingface/transformers/pull/24030
+ key_components = key.split(".")
+ name = None
+ if key_components[-3::2] == ["parametrizations", "original0"]:
+ name = key_components[-2] + "_g"
+ elif key_components[-3::2] == ["parametrizations", "original1"]:
+ name = key_components[-2] + "_v"
+ if name is not None:
+ key_components = key_components[:-3] + [name]
+ new_key = ".".join(key_components)
+
+ if new_key is None:
+ new_key = key
+ tf_keys_to_pt_keys[new_key] = key
+
+ # Matt: All TF models store the actual model stem in a MainLayer class, including the base model.
+ # In PT, the derived models (with heads) use the base model class as the stem instead,
+ # and there is no MainLayer class. This means that TF base classes have one
+ # extra layer in their weight names, corresponding to the MainLayer class. This code block compensates for that.
+ start_prefix_to_remove = ""
+ if not any(s.startswith(tf_model.base_model_prefix) for s in tf_keys_to_pt_keys):
+ start_prefix_to_remove = tf_model.base_model_prefix + "."
+
+ symbolic_weights = tf_model.trainable_weights + tf_model.non_trainable_weights
+ tf_loaded_numel = 0
+ all_pytorch_weights = set(tf_keys_to_pt_keys.keys())
+ missing_keys = []
+ mismatched_keys = []
+ is_safetensor_archive = hasattr(pt_state_dict, "get_tensor")
+ for symbolic_weight in symbolic_weights:
+ sw_name = symbolic_weight.name
+ name, transpose = convert_tf_weight_name_to_pt_weight_name(
+ sw_name,
+ start_prefix_to_remove=start_prefix_to_remove,
+ tf_weight_shape=symbolic_weight.shape,
+ name_scope=_prefix,
+ )
+ if tf_to_pt_weight_rename is not None:
+ aliases = tf_to_pt_weight_rename(name) # Is a tuple to account for possible name aliasing
+ for alias in aliases: # The aliases are in priority order, take the first one that matches
+ if alias in tf_keys_to_pt_keys:
+ name = alias
+ break
+ else:
+ # If none of the aliases match, just use the first one (it'll be reported as missing)
+ name = aliases[0]
+
+ # Find associated numpy array in pytorch model state dict
+ if name not in tf_keys_to_pt_keys:
+ if allow_missing_keys:
+ missing_keys.append(name)
+ continue
+ elif tf_model._keys_to_ignore_on_load_missing is not None:
+ # authorized missing keys don't have to be loaded
+ if any(re.search(pat, name) is not None for pat in tf_model._keys_to_ignore_on_load_missing):
+ continue
+ raise AttributeError(f"{name} not found in PyTorch model")
+ state_dict_name = tf_keys_to_pt_keys[name]
+ if is_safetensor_archive:
+ array = pt_state_dict.get_tensor(state_dict_name)
+ else:
+ array = pt_state_dict[state_dict_name]
+ try:
+ array = apply_transpose(transpose, array, symbolic_weight.shape)
+ except tf.errors.InvalidArgumentError as e:
+ if not ignore_mismatched_sizes:
+ error_msg = str(e)
+ error_msg += (
+ "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
+ )
+ raise tf.errors.InvalidArgumentError(error_msg)
+ else:
+ mismatched_keys.append((name, array.shape, symbolic_weight.shape))
+ continue
+
+ tf_loaded_numel += tensor_size(array)
+
+ symbolic_weight.assign(tf.cast(array, symbolic_weight.dtype))
+ del array # Immediately free memory to keep peak usage as low as possible
+ all_pytorch_weights.discard(name)
+
+ logger.info(f"Loaded {tf_loaded_numel:,} parameters in the TF 2.0 model.")
+
+ unexpected_keys = list(all_pytorch_weights)
+
+ if tf_model._keys_to_ignore_on_load_missing is not None:
+ for pat in tf_model._keys_to_ignore_on_load_missing:
+ missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
+ if tf_model._keys_to_ignore_on_load_unexpected is not None:
+ for pat in tf_model._keys_to_ignore_on_load_unexpected:
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
+ if not skip_logger_warnings:
+ _log_key_warnings(missing_keys, unexpected_keys, mismatched_keys, class_name=tf_model.__class__.__name__)
+
+ if output_loading_info:
+ loading_info = {
+ "missing_keys": missing_keys,
+ "unexpected_keys": unexpected_keys,
+ "mismatched_keys": mismatched_keys,
+ }
+ return tf_model, loading_info
+
+ return tf_model
+
+
+def load_sharded_pytorch_safetensors_in_tf2_model(
+ tf_model,
+ safetensors_shards,
+ tf_inputs=None,
+ allow_missing_keys=False,
+ output_loading_info=False,
+ _prefix=None,
+ tf_to_pt_weight_rename=None,
+ ignore_mismatched_sizes=False,
+):
+ all_loading_infos = []
+ for shard in safetensors_shards:
+ with safe_open(shard, framework="tf") as safetensors_archive:
+ tf_model, loading_info = load_pytorch_state_dict_in_tf2_model(
+ tf_model,
+ safetensors_archive,
+ tf_inputs=tf_inputs,
+ allow_missing_keys=allow_missing_keys,
+ output_loading_info=True,
+ _prefix=_prefix,
+ tf_to_pt_weight_rename=tf_to_pt_weight_rename,
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
+ skip_logger_warnings=True, # We will emit merged warnings at the end
+ )
+ all_loading_infos.append(loading_info)
+ # Now we just need to merge the loading info
+ # Keys are missing only if they're missing in *every* shard
+ missing_keys = sorted(set.intersection(*[set(info["missing_keys"]) for info in all_loading_infos]))
+ # Keys are unexpected/mismatched if they're unexpected/mismatched in *any* shard
+ unexpected_keys = sum([info["unexpected_keys"] for info in all_loading_infos], [])
+ mismatched_keys = sum([info["mismatched_keys"] for info in all_loading_infos], [])
+
+ _log_key_warnings(missing_keys, unexpected_keys, mismatched_keys, class_name=tf_model.__class__.__name__)
+
+ if output_loading_info:
+ loading_info = {
+ "missing_keys": missing_keys,
+ "unexpected_keys": unexpected_keys,
+ "mismatched_keys": mismatched_keys,
+ }
+ return tf_model, loading_info
+
+ return tf_model
+
+
+#####################
+# TF 2.0 => PyTorch #
+#####################
+
+
+def load_tf2_checkpoint_in_pytorch_model(
+ pt_model, tf_checkpoint_path, tf_inputs=None, allow_missing_keys=False, output_loading_info=False
+):
+ """
+ Load TF 2.0 HDF5 checkpoint in a PyTorch model We use HDF5 to easily do transfer learning (see
+ https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357).
+ """
+ try:
+ import tensorflow as tf # noqa: F401
+ import torch # noqa: F401
+ except ImportError:
+ logger.error(
+ "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
+ "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
+ )
+ raise
+
+ import transformers
+
+ from .modeling_tf_utils import load_tf_weights
+
+ logger.info(f"Loading TensorFlow weights from {tf_checkpoint_path}")
+
+ # Instantiate and load the associated TF 2.0 model
+ tf_model_class_name = "TF" + pt_model.__class__.__name__ # Add "TF" at the beginning
+ tf_model_class = getattr(transformers, tf_model_class_name)
+ tf_model = tf_model_class(pt_model.config)
+
+ if tf_inputs is None:
+ tf_inputs = tf_model.dummy_inputs
+
+ if tf_inputs is not None:
+ tf_model(tf_inputs, training=False) # Make sure model is built
+
+ load_tf_weights(tf_model, tf_checkpoint_path)
+
+ return load_tf2_model_in_pytorch_model(
+ pt_model, tf_model, allow_missing_keys=allow_missing_keys, output_loading_info=output_loading_info
+ )
+
+
+def load_tf2_model_in_pytorch_model(pt_model, tf_model, allow_missing_keys=False, output_loading_info=False):
+ """Load TF 2.0 model in a pytorch model"""
+ weights = tf_model.weights
+
+ return load_tf2_weights_in_pytorch_model(
+ pt_model, weights, allow_missing_keys=allow_missing_keys, output_loading_info=output_loading_info
+ )
+
+
+def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=False, output_loading_info=False):
+ """Load TF2.0 symbolic weights in a PyTorch model"""
+ try:
+ import tensorflow as tf # noqa: F401
+ import torch # noqa: F401
+ except ImportError:
+ logger.error(
+ "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
+ "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
+ )
+ raise
+
+ tf_state_dict = {tf_weight.name: tf_weight.numpy() for tf_weight in tf_weights}
+ return load_tf2_state_dict_in_pytorch_model(
+ pt_model, tf_state_dict, allow_missing_keys=allow_missing_keys, output_loading_info=output_loading_info
+ )
+
+
+def load_tf2_state_dict_in_pytorch_model(pt_model, tf_state_dict, allow_missing_keys=False, output_loading_info=False):
+ import torch
+
+ new_pt_params_dict = {}
+ current_pt_params_dict = dict(pt_model.named_parameters())
+
+ # Make sure we are able to load PyTorch base models as well as derived models (with heads)
+ # TF models always have a prefix, some of PyTorch models (base ones) don't
+ start_prefix_to_remove = ""
+ if not any(s.startswith(pt_model.base_model_prefix) for s in current_pt_params_dict):
+ start_prefix_to_remove = pt_model.base_model_prefix + "."
+
+ # Build a map from potential PyTorch weight names to TF 2.0 Variables
+ tf_weights_map = {}
+ for name, tf_weight in tf_state_dict.items():
+ pt_name, transpose = convert_tf_weight_name_to_pt_weight_name(
+ name, start_prefix_to_remove=start_prefix_to_remove, tf_weight_shape=tf_weight.shape
+ )
+ tf_weights_map[pt_name] = (tf_weight, transpose)
+
+ all_tf_weights = set(tf_weights_map.keys())
+ loaded_pt_weights_data_ptr = {}
+ missing_keys_pt = []
+ for pt_weight_name, pt_weight in current_pt_params_dict.items():
+ # Handle PyTorch shared weight not duplicated in TF 2.0
+ if pt_weight.data_ptr() in loaded_pt_weights_data_ptr and pt_weight.data_ptr() != 0:
+ new_pt_params_dict[pt_weight_name] = loaded_pt_weights_data_ptr[pt_weight.data_ptr()]
+ continue
+
+ pt_weight_name_to_check = pt_weight_name
+ # New `weight_norm` from https://github.com/huggingface/transformers/pull/24030
+ key_components = pt_weight_name.split(".")
+ name = None
+ if key_components[-3::2] == ["parametrizations", "original0"]:
+ name = key_components[-2] + "_g"
+ elif key_components[-3::2] == ["parametrizations", "original1"]:
+ name = key_components[-2] + "_v"
+ if name is not None:
+ key_components = key_components[:-3] + [name]
+ pt_weight_name_to_check = ".".join(key_components)
+
+ # Find associated numpy array in pytorch model state dict
+ if pt_weight_name_to_check not in tf_weights_map:
+ if allow_missing_keys:
+ missing_keys_pt.append(pt_weight_name)
+ continue
+
+ raise AttributeError(f"{pt_weight_name} not found in TF 2.0 model")
+
+ array, transpose = tf_weights_map[pt_weight_name_to_check]
+
+ array = apply_transpose(transpose, array, pt_weight.shape, pt_to_tf=False)
+
+ if numpy.isscalar(array):
+ array = numpy.array(array)
+ if not is_torch_tensor(array) and not is_numpy_array(array):
+ array = array.numpy()
+ if is_numpy_array(array):
+ # Convert to torch tensor
+ array = torch.from_numpy(array)
+
+ new_pt_params_dict[pt_weight_name] = array
+ loaded_pt_weights_data_ptr[pt_weight.data_ptr()] = array
+ all_tf_weights.discard(pt_weight_name)
+
+ missing_keys, unexpected_keys = pt_model.load_state_dict(new_pt_params_dict, strict=False)
+ missing_keys += missing_keys_pt
+
+ # Some models may have keys that are not in the state by design, removing them before needlessly warning
+ # the user.
+ if pt_model._keys_to_ignore_on_load_missing is not None:
+ for pat in pt_model._keys_to_ignore_on_load_missing:
+ missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
+
+ if pt_model._keys_to_ignore_on_load_unexpected is not None:
+ for pat in pt_model._keys_to_ignore_on_load_unexpected:
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
+
+ if len(unexpected_keys) > 0:
+ logger.warning(
+ "Some weights of the TF 2.0 model were not used when initializing the PyTorch model"
+ f" {pt_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing"
+ f" {pt_model.__class__.__name__} from a TF 2.0 model trained on another task or with another architecture"
+ " (e.g. initializing a BertForSequenceClassification model from a TFBertForPreTraining model).\n- This IS"
+ f" NOT expected if you are initializing {pt_model.__class__.__name__} from a TF 2.0 model that you expect"
+ " to be exactly identical (e.g. initializing a BertForSequenceClassification model from a"
+ " TFBertForSequenceClassification model)."
+ )
+ else:
+ logger.warning(f"All TF 2.0 model weights were used when initializing {pt_model.__class__.__name__}.\n")
+ if len(missing_keys) > 0:
+ logger.warning(
+ f"Some weights of {pt_model.__class__.__name__} were not initialized from the TF 2.0 model and are newly"
+ f" initialized: {missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to"
+ " use it for predictions and inference."
+ )
+ else:
+ logger.warning(
+ f"All the weights of {pt_model.__class__.__name__} were initialized from the TF 2.0 model.\n"
+ "If your task is similar to the task the model of the checkpoint was trained on, "
+ f"you can already use {pt_model.__class__.__name__} for predictions without further training."
+ )
+
+ logger.info(f"Weights or buffers not loaded from TF 2.0 model: {all_tf_weights}")
+
+ if output_loading_info:
+ loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys}
+ return pt_model, loading_info
+
+ return pt_model
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_tf_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_tf_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7bb80656d1b8596e9f3313b6f768c6d2251099d
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_tf_utils.py
@@ -0,0 +1,3529 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""TF general model utils."""
+
+from __future__ import annotations
+
+import functools
+import gc
+import inspect
+import json
+import os
+import pickle
+import re
+import warnings
+from collections.abc import Mapping
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Callable, Union
+
+import h5py
+import numpy as np
+import tensorflow as tf
+from packaging.version import parse
+
+from . import DataCollatorWithPadding, DefaultDataCollator
+from .activations_tf import get_tf_activation
+from .configuration_utils import PretrainedConfig
+from .dynamic_module_utils import custom_object_save
+from .generation import GenerationConfig, TFGenerationMixin
+from .tf_utils import (
+ convert_batch_encoding,
+ expand_1d,
+ load_attributes_from_hdf5_group,
+ save_attributes_to_hdf5_group,
+ shape_list,
+)
+from .utils import (
+ SAFE_WEIGHTS_INDEX_NAME,
+ SAFE_WEIGHTS_NAME,
+ TF2_WEIGHTS_INDEX_NAME,
+ TF2_WEIGHTS_NAME,
+ TF_WEIGHTS_NAME,
+ WEIGHTS_INDEX_NAME,
+ WEIGHTS_NAME,
+ ModelOutput,
+ PushToHubMixin,
+ cached_file,
+ download_url,
+ find_labels,
+ has_file,
+ is_offline_mode,
+ is_remote_url,
+ is_safetensors_available,
+ is_tf_symbolic_tensor,
+ logging,
+ requires_backends,
+ working_or_temp_dir,
+)
+from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
+
+
+if is_safetensors_available():
+ from safetensors import safe_open
+ from safetensors.tensorflow import save_file as safe_save_file
+
+if TYPE_CHECKING:
+ from . import PreTrainedTokenizerBase
+
+logger = logging.get_logger(__name__)
+
+if "TF_USE_LEGACY_KERAS" not in os.environ:
+ os.environ["TF_USE_LEGACY_KERAS"] = "1" # Compatibility fix to make sure tf.keras stays at Keras 2
+elif os.environ["TF_USE_LEGACY_KERAS"] != "1":
+ logger.warning(
+ "Transformers is only compatible with Keras 2, but you have explicitly set `TF_USE_LEGACY_KERAS` to `0`. "
+ "This may result in unexpected behaviour or errors if Keras 3 objects are passed to Transformers models."
+ )
+
+try:
+ import tf_keras as keras
+ from tf_keras import backend as K
+except (ModuleNotFoundError, ImportError):
+ import keras
+ from keras import backend as K
+
+ if parse(keras.__version__).major > 2:
+ raise ValueError(
+ "Your currently installed version of Keras is Keras 3, but this is not yet supported in "
+ "Transformers. Please install the backwards-compatible tf-keras package with "
+ "`pip install tf-keras`."
+ )
+
+
+tf_logger = tf.get_logger()
+
+TFModelInputType = Union[
+ list[tf.Tensor],
+ list[np.ndarray],
+ dict[str, tf.Tensor],
+ dict[str, np.ndarray],
+ tf.Tensor,
+ np.ndarray,
+]
+
+
+def dummy_loss(y_true, y_pred):
+ if y_pred.shape.rank <= 1:
+ return y_pred
+ else:
+ reduction_axes = list(range(1, y_pred.shape.rank))
+ return tf.reduce_mean(y_pred, axis=reduction_axes)
+
+
+class TFModelUtilsMixin:
+ """
+ A few utilities for `keras.Model`, to be used as a mixin.
+ """
+
+ def num_parameters(self, only_trainable: bool = False) -> int:
+ """
+ Get the number of (optionally, trainable) parameters in the model.
+
+ Args:
+ only_trainable (`bool`, *optional*, defaults to `False`):
+ Whether or not to return only the number of trainable parameters
+
+ Returns:
+ `int`: The number of parameters.
+ """
+ if only_trainable:
+ return int(sum(np.prod(w.shape.as_list()) for w in self.trainable_variables))
+ else:
+ return self.count_params()
+
+
+def keras_serializable(cls):
+ """
+ Decorate a Keras Layer class to support Keras serialization.
+
+ This is done by:
+
+ 1. Adding a `transformers_config` dict to the Keras config dictionary in `get_config` (called by Keras at
+ serialization time.
+ 2. Wrapping `__init__` to accept that `transformers_config` dict (passed by Keras at deserialization time) and
+ convert it to a config object for the actual layer initializer.
+ 3. Registering the class as a custom object in Keras (if the Tensorflow version supports this), so that it does not
+ need to be supplied in `custom_objects` in the call to `keras.models.load_model`.
+
+ Args:
+ cls (a `keras.layers.Layers subclass`):
+ Typically a `TF.MainLayer` class in this project, in general must accept a `config` argument to its
+ initializer.
+
+ Returns:
+ The same class object, with modifications for Keras deserialization.
+ """
+ initializer = cls.__init__
+
+ config_class = getattr(cls, "config_class", None)
+ if config_class is None:
+ raise AttributeError("Must set `config_class` to use @keras_serializable")
+
+ @functools.wraps(initializer)
+ def wrapped_init(self, *args, **kwargs):
+ config = args[0] if args and isinstance(args[0], PretrainedConfig) else kwargs.pop("config", None)
+
+ if isinstance(config, dict):
+ config = config_class.from_dict(config)
+ initializer(self, config, *args, **kwargs)
+ elif isinstance(config, PretrainedConfig):
+ if len(args) > 0:
+ initializer(self, *args, **kwargs)
+ else:
+ initializer(self, config, *args, **kwargs)
+ else:
+ raise TypeError("Must pass either `config` (PretrainedConfig) or `config` (dict)")
+
+ self._config = config
+ self._kwargs = kwargs
+
+ cls.__init__ = wrapped_init
+
+ if not hasattr(cls, "get_config"):
+ raise TypeError("Only use @keras_serializable on keras.layers.Layer subclasses")
+ if hasattr(cls.get_config, "_is_default"):
+
+ def get_config(self):
+ cfg = super(cls, self).get_config()
+ cfg["config"] = self._config.to_dict()
+ cfg.update(self._kwargs)
+ return cfg
+
+ cls.get_config = get_config
+
+ cls._keras_serializable = True
+ if hasattr(keras.utils, "register_keras_serializable"):
+ cls = keras.utils.register_keras_serializable()(cls)
+ return cls
+
+
+class TFCausalLanguageModelingLoss:
+ """
+ Loss function suitable for causal language modeling (CLM), that is, the task of guessing the next token.
+
+
+
+ Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.
+
+
+ """
+
+ def hf_compute_loss(self, labels, logits):
+ loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE)
+ if self.config.tf_legacy_loss:
+ # make sure only labels that are not equal to -100 affect the loss
+ active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
+ reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
+ labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
+ return loss_fn(labels, reduced_logits)
+
+ # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
+ unmasked_loss = loss_fn(tf.nn.relu(labels), logits)
+ # make sure only labels that are not equal to -100 affect the loss
+ loss_mask = tf.cast(labels != -100, dtype=unmasked_loss.dtype)
+ masked_loss = unmasked_loss * loss_mask
+ reduced_masked_loss = tf.reduce_sum(masked_loss) / tf.reduce_sum(loss_mask)
+ return tf.reshape(reduced_masked_loss, (1,))
+
+
+class TFQuestionAnsweringLoss:
+ """
+ Loss function suitable for question answering.
+ """
+
+ def hf_compute_loss(self, labels, logits):
+ loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE)
+ start_loss = loss_fn(labels["start_position"], logits[0])
+ end_loss = loss_fn(labels["end_position"], logits[1])
+
+ return (start_loss + end_loss) / 2.0
+
+
+class TFTokenClassificationLoss:
+ """
+ Loss function suitable for token classification.
+
+
+
+ Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.
+
+
+ """
+
+ def hf_compute_loss(self, labels, logits):
+ loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE)
+ if tf.executing_eagerly(): # Data-dependent conditionals are forbidden in XLA
+ if tf.math.reduce_any(labels == -1):
+ tf.print("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.")
+
+ if self.config.tf_legacy_loss:
+ # make sure only labels that are not equal to -100
+ # are taken into account as loss
+ if tf.math.reduce_any(labels == -1):
+ tf.print("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.")
+ active_loss = tf.reshape(labels, (-1,)) != -1
+ else:
+ active_loss = tf.reshape(labels, (-1,)) != -100
+ reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
+ labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
+
+ return loss_fn(labels, reduced_logits)
+
+ # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
+ unmasked_loss = loss_fn(tf.nn.relu(labels), logits)
+ # make sure only labels that are not equal to -100 or -1
+ # are taken into account as loss
+ loss_mask = tf.cast(labels >= 0, dtype=unmasked_loss.dtype)
+ # Avoid possible division by zero later
+ # Masked positions will have a loss of NaN because -100 and -1 are not valid labels
+ masked_loss = unmasked_loss * loss_mask
+ reduced_masked_loss = tf.reduce_sum(masked_loss) / tf.reduce_sum(loss_mask)
+ return tf.reshape(reduced_masked_loss, (1,))
+
+
+class TFSequenceClassificationLoss:
+ """
+ Loss function suitable for sequence classification.
+ """
+
+ def hf_compute_loss(self, labels, logits):
+ if logits.shape.rank == 1 or logits.shape[1] == 1:
+ loss_fn = keras.losses.MeanSquaredError(reduction=keras.losses.Reduction.NONE)
+ if labels.shape.rank == 1:
+ # MeanSquaredError returns a scalar loss if the labels are 1D, so avoid that
+ labels = tf.expand_dims(labels, axis=-1)
+ else:
+ loss_fn = keras.losses.SparseCategoricalCrossentropy(
+ from_logits=True, reduction=keras.losses.Reduction.NONE
+ )
+
+ return loss_fn(labels, logits)
+
+
+class TFMultipleChoiceLoss:
+ """Loss function suitable for multiple choice tasks."""
+
+ def hf_compute_loss(self, labels, logits):
+ loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE)
+ return loss_fn(labels, logits)
+
+
+class TFMaskedLanguageModelingLoss(TFCausalLanguageModelingLoss):
+ """
+ Loss function suitable for masked language modeling (MLM), that is, the task of guessing the masked tokens.
+
+
+
+ Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.
+
+
+ """
+
+
+class TFNextSentencePredictionLoss:
+ """
+ Loss function suitable for next sentence prediction (NSP), that is, the task of guessing the next sentence.
+
+
+
+ Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.
+
+
+ """
+
+ def hf_compute_loss(self, labels, logits):
+ loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE)
+ if self.config.tf_legacy_loss:
+ # make sure only labels that are not equal to -100
+ # are taken into account as loss
+ next_sentence_active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
+ next_sentence_reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, 2)), next_sentence_active_loss)
+ next_sentence_label = tf.boolean_mask(tf.reshape(labels, (-1,)), next_sentence_active_loss)
+
+ return loss_fn(next_sentence_label, next_sentence_reduced_logits)
+
+ # make sure only labels that are not equal to -100
+ # are taken into account as loss
+
+ # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
+ unmasked_ns_loss = loss_fn(y_true=tf.nn.relu(labels), y_pred=logits)
+ ns_loss_mask = tf.cast(labels != -100, dtype=unmasked_ns_loss.dtype)
+ # Just zero out samples where label is -100, no reduction
+ masked_ns_loss = unmasked_ns_loss * ns_loss_mask
+
+ return masked_ns_loss
+
+
+def booleans_processing(config, **kwargs):
+ """
+ Process the input booleans of each model.
+
+ Args:
+ config ([`PretrainedConfig`]):
+ The config of the running model.
+ **kwargs:
+ The boolean parameters
+
+ Returns:
+ A dictionary with the proper values for each boolean
+ """
+ final_booleans = {}
+
+ # Pure conv models (such as ConvNext) do not have `output_attentions`. If the signature has
+ # `output_attentions`, it will be present here in `kwargs`, even if unset (in that case, as `None`)
+ if "output_attentions" in kwargs:
+ final_booleans["output_attentions"] = (
+ kwargs["output_attentions"] if kwargs["output_attentions"] is not None else config.output_attentions
+ )
+ final_booleans["output_hidden_states"] = (
+ kwargs["output_hidden_states"] if kwargs["output_hidden_states"] is not None else config.output_hidden_states
+ )
+ final_booleans["return_dict"] = kwargs["return_dict"] if kwargs["return_dict"] is not None else config.return_dict
+
+ if "use_cache" in kwargs:
+ final_booleans["use_cache"] = (
+ kwargs["use_cache"] if kwargs["use_cache"] is not None else getattr(config, "use_cache", None)
+ )
+ return final_booleans
+
+
+def unpack_inputs(func):
+ """
+ Decorator that processes the inputs to a Keras layer, passing them to the layer as keyword arguments. This enables
+ downstream use of the inputs by their variable name, even if they arrive packed as a dictionary in the first input
+ (common case in Keras).
+
+ Args:
+ func (`callable`):
+ The callable function of the TensorFlow model.
+
+
+ Returns:
+ A callable that wraps the original `func` with the behavior described above.
+ """
+
+ original_signature = inspect.signature(func)
+
+ @functools.wraps(func)
+ def run_call_with_unpacked_inputs(self, *args, **kwargs):
+ # isolates the actual `**kwargs` for the decorated function
+ kwargs_call = {key: val for key, val in kwargs.items() if key not in dict(original_signature.parameters)}
+ fn_args_and_kwargs = {key: val for key, val in kwargs.items() if key not in kwargs_call}
+ fn_args_and_kwargs.update({"kwargs_call": kwargs_call})
+
+ # move any arg into kwargs, if they exist
+ fn_args_and_kwargs.update(dict(zip(func.__code__.co_varnames[1:], args)))
+
+ # Encoder Decoder models delegate the application of the configuration options to their inner models.
+ if "EncoderDecoder" in self.__class__.__name__:
+ config = None
+ else:
+ config = self.config
+
+ unpacked_inputs = input_processing(func, config, **fn_args_and_kwargs)
+ return func(self, **unpacked_inputs)
+
+ # Keras enforces the first layer argument to be passed, and checks it through `inspect.getfullargspec()`. This
+ # function does not follow wrapper chains (i.e. ignores `functools.wraps()`), meaning that without the line below
+ # Keras would attempt to check the first argument against the literal signature of the wrapper.
+ run_call_with_unpacked_inputs.__signature__ = original_signature
+
+ return run_call_with_unpacked_inputs
+
+
+def input_processing(func, config, **kwargs):
+ """
+ Process the input of each TensorFlow model including the booleans. In case of a list of symbolic inputs, each input
+ has to be named accordingly to the parameters name, i.e. `input_ids = keras.Input(shape=(128,), dtype='int32',
+ name="input_ids")` otherwise the order of the tensors will not be guaranteed during the training.
+
+ Args:
+ func (`callable`):
+ The callable function of the TensorFlow model.
+ config ([`PretrainedConfig`]):
+ The config of the running model.
+ **kwargs:
+ The inputs of the model.
+
+ Returns:
+ Two lists, one for the missing layers, and another one for the unexpected layers.
+ """
+ signature = dict(inspect.signature(func).parameters)
+ has_kwargs = bool(signature.pop("kwargs", None))
+ signature.pop("self", None)
+ parameter_names = list(signature.keys())
+ main_input_name = parameter_names[0]
+ main_input = kwargs.pop(main_input_name, None)
+ output = {}
+ allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict, np.ndarray)
+
+ if "inputs" in kwargs["kwargs_call"]:
+ warnings.warn(
+ "The `inputs` argument is deprecated and will be removed in a future version, use `input_ids` instead.",
+ FutureWarning,
+ )
+
+ output["input_ids"] = kwargs["kwargs_call"].pop("inputs")
+
+ if "decoder_cached_states" in kwargs["kwargs_call"]:
+ warnings.warn(
+ "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use"
+ " `past_key_values` instead.",
+ FutureWarning,
+ )
+ output["past_key_values"] = kwargs["kwargs_call"].pop("decoder_cached_states")
+
+ if "past" in kwargs["kwargs_call"] and "past_key_values" in parameter_names:
+ warnings.warn(
+ "The `past` argument is deprecated and will be removed in a future version, use `past_key_values`"
+ " instead.",
+ FutureWarning,
+ )
+ kwargs["past_key_values"] = kwargs["kwargs_call"].pop("past")
+ elif "past_key_values" in kwargs["kwargs_call"] and "past" in parameter_names:
+ kwargs["past"] = kwargs["kwargs_call"].pop("past_key_values")
+
+ if has_kwargs:
+ output["kwargs"] = kwargs.pop("kwargs_call", {})
+ else:
+ if len(kwargs["kwargs_call"]) > 0:
+ raise ValueError(
+ "The following keyword arguments are not supported by this model:"
+ f" {list(kwargs['kwargs_call'].keys())}."
+ )
+ kwargs.pop("kwargs_call")
+
+ for k, v in kwargs.items():
+ if isinstance(v, allowed_types) or tf.is_tensor(v) or v is None:
+ output[k] = v
+ else:
+ raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")
+
+ if isinstance(main_input, (tuple, list)):
+ for i, input in enumerate(main_input):
+ # EagerTensors don't allow to use the .name property so we check for a real Tensor
+ if is_tf_symbolic_tensor(input):
+ # Tensor names have always the pattern `name:id` then we check only the
+ # `name` part
+ tensor_name = input.name.split(":")[0]
+
+ if tensor_name in parameter_names:
+ output[tensor_name] = input
+ else:
+ output[parameter_names[i]] = input
+ elif isinstance(input, allowed_types) or input is None:
+ output[parameter_names[i]] = input
+ else:
+ raise ValueError(
+ f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for"
+ f" {parameter_names[i]}."
+ )
+ elif isinstance(main_input, Mapping):
+ if "inputs" in main_input:
+ warnings.warn(
+ "The `inputs` argument is deprecated and will be removed in a future version, use `input_ids`"
+ " instead.",
+ FutureWarning,
+ )
+
+ output["input_ids"] = main_input.pop("inputs")
+
+ if "decoder_cached_states" in main_input:
+ warnings.warn(
+ "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use"
+ " `past_key_values` instead.",
+ FutureWarning,
+ )
+ output["past_key_values"] = main_input.pop("decoder_cached_states")
+
+ for k, v in dict(main_input).items():
+ if isinstance(v, allowed_types) or v is None:
+ output[k] = v
+ elif k not in parameter_names and "args" not in parameter_names:
+ logger.warning(
+ f"The parameter {k} does not belongs to the parameter list {parameter_names} and will be ignored."
+ )
+ continue
+ else:
+ raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")
+ else:
+ if tf.is_tensor(main_input) or main_input is None:
+ output[main_input_name] = main_input
+ else:
+ raise ValueError(
+ f"Data of type {type(main_input)} is not allowed only {allowed_types} is accepted for"
+ f" {main_input_name}."
+ )
+
+ # Populates any unspecified argument with their default value, according to the signature.
+ for name in parameter_names:
+ if name not in list(output.keys()) and name != "args":
+ output[name] = kwargs.pop(name, signature[name].default)
+
+ # When creating a SavedModel TF calls the method with LayerCall.__call__(args, **kwargs)
+ # So to respect the proper output we have to add this exception
+ if "args" in output:
+ if output["args"] is not None and is_tf_symbolic_tensor(output["args"]):
+ tensor_name = output["args"].name.split(":")[0]
+ output[tensor_name] = output["args"]
+ else:
+ # `args` in this case is always the first parameter, then `input_ids`
+ output["input_ids"] = output["args"]
+
+ del output["args"]
+
+ if "kwargs" in output:
+ del output["kwargs"]
+
+ cast_output = {}
+ for key, val in output.items():
+ if isinstance(val, tf.Tensor) and val.dtype == tf.int64:
+ cast_output[key] = tf.cast(val, tf.int32)
+ elif isinstance(val, np.ndarray) and val.dtype == np.int64:
+ cast_output[key] = val.astype(np.int32)
+ else:
+ cast_output[key] = val
+
+ output = cast_output
+ del cast_output
+
+ if config is not None:
+ boolean_dict = {
+ k: v
+ for k, v in output.items()
+ if k in ["return_dict", "output_attentions", "output_hidden_states", "use_cache"]
+ }
+
+ output.update(
+ booleans_processing(
+ config=config,
+ **boolean_dict,
+ )
+ )
+
+ return output
+
+
+def strip_model_name_and_prefix(name, _prefix=None):
+ if _prefix is not None and name.startswith(_prefix):
+ name = name[len(_prefix) :]
+ if name.startswith("/"):
+ name = name[1:]
+ if "model." not in name and len(name.split("/")) > 1:
+ name = "/".join(name.split("/")[1:])
+ return name
+
+
+def tf_shard_checkpoint(weights, max_shard_size="10GB", weights_name: str = TF2_WEIGHTS_NAME):
+ """
+ Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
+ given size.
+
+ The sub-checkpoints are determined by iterating through the `state_dict` in the order of its keys, so there is no
+ optimization made to make each sub-checkpoint as close as possible to the maximum size passed. For example, if the
+ limit is 10GB and we have weights of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB],
+ [6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB].
+
+
+
+ If one of the model's weight is bigger that `max_shard_size`, it will end up in its own sub-checkpoint which will
+ have a size greater than `max_shard_size`.
+
+
+
+ Args:
+ weights (`dict[str, tf.RessourceVariable]`): The list of tf.RessourceVariable of a model to save.
+ max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
+ The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit
+ (like `"5MB"`).
+ """
+ max_shard_size = convert_file_size_to_int(max_shard_size)
+
+ sharded_state_dicts = []
+ current_block = []
+ current_block_size = 0
+ total_size = 0
+
+ for item in weights:
+ weight_size = item.numpy().size * item.dtype.size
+
+ # If this weight is going to tip up over the maximal size, we split.
+ if current_block_size + weight_size > max_shard_size:
+ sharded_state_dicts.append(current_block)
+ current_block = []
+ current_block_size = 0
+
+ current_block.append(item)
+ current_block_size += weight_size
+ total_size += weight_size
+
+ # Add the last block
+ sharded_state_dicts.append(current_block)
+
+ # If we only have one shard, we return it
+ if len(sharded_state_dicts) == 1:
+ return {weights_name: sharded_state_dicts[0]}, None
+
+ # Otherwise, let's build the index
+ weight_map = {}
+ shards = {}
+ for idx, shard in enumerate(sharded_state_dicts):
+ shard_file = weights_name.replace(".h5", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.h5")
+ shard_file = shard_file.replace(
+ ".safetensors", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors"
+ )
+ shards[shard_file] = shard
+ for weight in shard:
+ weight_name = weight.name
+ weight_map[weight_name] = shard_file
+
+ # Add the metadata
+ metadata = {"total_size": total_size}
+ index = {"metadata": metadata, "weight_map": weight_map}
+ return shards, index
+
+
+def load_tf_sharded_weights(model, shard_files, ignore_mismatched_sizes=False, strict=False, _prefix=None):
+ """
+ This is the same as `load_tf_weights` but for a sharded checkpoint. Detect missing and unexpected layers and load
+ the TF weights from the shard file accordingly to their names and shapes.
+
+ This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being
+ loaded in the model.
+
+ Args:
+ model (`keras.models.Model`): The model in which to load the checkpoint.
+ shard_files (`str` or `os.PathLike`): A list containing the sharded checkpoint names.
+ ignore_mismatched_sizes`bool`, *optional`, defaults to `True`):
+ Whether or not to ignore the mismatch between the sizes
+ strict (`bool`, *optional*, defaults to `True`):
+ Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint.
+
+ Returns:
+ Three lists, one for the missing layers, another one for the unexpected layers, and a last one for the
+ mismatched layers.
+ """
+
+ # Load the index
+ unexpected_keys = set()
+ saved_keys = set()
+ mismatched_keys = set()
+
+ # Since TF adds the name of the class to its weights, and uses the index and not the name of the layer to load
+ # the weight, we have to get rid of the first prefix of the name of the layer.
+ model_keys = set()
+ model_layer_map = {}
+ for i, k in enumerate(model.weights):
+ layer_name = k.name
+ if _prefix is not None and layer_name.startswith(_prefix):
+ layer_name = layer_name[len(_prefix) :]
+ layer_name = layer_name.lstrip("/")
+ if not ("model." in layer_name or len(layer_name.split("/")) == 1):
+ layer_name = "/".join(layer_name.split("/")[1:])
+ model_keys.add(layer_name)
+ model_layer_map[layer_name] = i
+
+ for shard_file in shard_files:
+ saved_weight_names_set, unexpected_keys_set, mismatched_keys_set = load_tf_shard(
+ model,
+ model_layer_map,
+ shard_file,
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
+ _prefix=_prefix,
+ )
+ saved_keys.update(saved_weight_names_set)
+ unexpected_keys.update(unexpected_keys_set)
+ mismatched_keys.update(mismatched_keys_set)
+ gc.collect()
+
+ missing_keys = model_keys - saved_keys
+ if strict and (len(missing_keys) > 0 or len(unexpected_keys) > 0):
+ error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}"
+ if len(missing_keys) > 0:
+ str_missing_keys = ",".join([f'"{k}"' for k in missing_keys])
+ error_message += f"\nMissing key(s): {str_missing_keys}."
+ if len(unexpected_keys) > 0:
+ str_unexpected_keys = ",".join([f'"{k}"' for k in unexpected_keys])
+ error_message += f"\nMissing key(s): {str_unexpected_keys}."
+ raise RuntimeError(error_message)
+
+ return missing_keys, unexpected_keys, mismatched_keys
+
+
+def load_tf_shard(model, model_layer_map, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None):
+ """
+ Loads a shard from a sharded checkpoint file. Can be either H5 or Safetensors.
+ Handles missing keys and unexpected keys.
+
+ Args:
+ model (`keras.models.Model`): Model in which the weights are loaded
+ model_layer_map (`Dict`): A dictionary mapping the layer name to the index of the layer in the model.
+ resolved_archive_file (`str`): Path to the checkpoint file from which the weights will be loaded
+ ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`): Whether to ignore the mismatched keys
+
+ Returns:
+ `keras.models.Model`: Three lists, one for the layers that were found and successfully restored (from the
+ shard file), one for the mismatched layers, and another one for the unexpected layers.
+ """
+ saved_weight_names_set = set()
+ saved_weights = {}
+ mismatched_keys = set()
+ unexpected_keys = set()
+ # Read the H5 file
+ try:
+ with h5py.File(resolved_archive_file, "r") as sharded_checkpoint_file:
+ # Retrieve the name of each layer from the H5 file
+ saved_h5_model_layers_name = set(load_attributes_from_hdf5_group(sharded_checkpoint_file, "layer_names"))
+ weight_value_tuples = []
+
+ # Compute missing and unexpected sub layers
+ # Store the weights in list of tuples that looks like [(weight_object, value_of_weight),...]
+ for layer_name in saved_h5_model_layers_name:
+ h5_layer_object = sharded_checkpoint_file[layer_name]
+ saved_weights[layer_name] = np.asarray(h5_layer_object)
+
+ saved_weight_names_set.add(layer_name)
+
+ if layer_name not in model_layer_map:
+ unexpected_keys.add(layer_name)
+ else:
+ symbolic_weight = model.weights[model_layer_map[layer_name]]
+
+ saved_weight_value = saved_weights[layer_name]
+ # If the current weight is found
+ if saved_weight_value is not None:
+ # Check if the shape of the current weight and the one from the H5 file are different
+ if K.int_shape(symbolic_weight) != saved_weight_value.shape:
+ # If yes we reshape the weight from the H5 file accordingly to the current weight
+ # If the two shapes are not compatible we raise an issue
+ try:
+ array = np.reshape(saved_weight_value, K.int_shape(symbolic_weight))
+ except ValueError as e:
+ if ignore_mismatched_sizes:
+ mismatched_keys.add(
+ (layer_name, saved_weight_value.shape, K.int_shape(symbolic_weight))
+ )
+ continue
+ else:
+ raise e
+ else:
+ array = saved_weight_value
+
+ # We create the tuple that will be loaded and add it to the final list
+ weight_value_tuples.append((symbolic_weight, array))
+
+ K.batch_set_value(weight_value_tuples)
+
+ return saved_weight_names_set, unexpected_keys, mismatched_keys
+
+ except Exception as e:
+ try:
+ with open(resolved_archive_file) as f:
+ if f.read().startswith("version"):
+ raise OSError(
+ "You seem to have cloned a repository without having git-lfs installed. Please install "
+ "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
+ "you cloned."
+ )
+ else:
+ raise ValueError(
+ f"Unable to locate the file {resolved_archive_file} which is necessary to load this pretrained"
+ " model. Make sure you have saved the model properly."
+ ) from e
+ except (UnicodeDecodeError, ValueError):
+ raise OSError(
+ f"Unable to load weights from TF checkpoint file for '{resolved_archive_file}' "
+ f"at '{resolved_archive_file}'. "
+ "If you tried to load a TF model from a sharded checkpoint, you should try converting the model "
+ "by loading it in pytorch and saving it locally. A conversion script should be released soon."
+ )
+
+
+def load_tf_sharded_weights_from_safetensors(
+ model, shard_files, ignore_mismatched_sizes=False, strict=False, _prefix=None
+):
+ """
+ This is the same as `load_tf_weights_from_safetensors` but for a sharded TF-format safetensors checkpoint.
+ Detect missing and unexpected layers and load the TF weights from the shard file accordingly to their names and
+ shapes.
+
+ This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being
+ loaded in the model.
+
+ Args:
+ model (`keras.models.Model`): The model in which to load the checkpoint.
+ shard_files (`str` or `os.PathLike`): A list containing the sharded checkpoint names.
+ ignore_mismatched_sizes`bool`, *optional`, defaults to `True`):
+ Whether or not to ignore the mismatch between the sizes
+ strict (`bool`, *optional*, defaults to `True`):
+ Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint.
+
+ Returns:
+ Three lists, one for the missing layers, another one for the unexpected layers, and a last one for the
+ mismatched layers.
+ """
+
+ # Load the index
+ unexpected_keys = set()
+ all_missing_keys = []
+ mismatched_keys = set()
+
+ for shard_file in shard_files:
+ missing_layers, unexpected_layers, mismatched_layers = load_tf_weights_from_safetensors(
+ model,
+ shard_file,
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
+ _prefix=_prefix,
+ )
+ all_missing_keys.append(set(missing_layers))
+ unexpected_keys.update(unexpected_layers)
+ mismatched_keys.update(mismatched_layers)
+ gc.collect()
+ missing_keys = set.intersection(*all_missing_keys)
+
+ if strict and (len(missing_keys) > 0 or len(unexpected_keys) > 0):
+ error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}"
+ if len(missing_keys) > 0:
+ str_missing_keys = ",".join([f'"{k}"' for k in missing_keys])
+ error_message += f"\nMissing key(s): {str_missing_keys}."
+ if len(unexpected_keys) > 0:
+ str_unexpected_keys = ",".join([f'"{k}"' for k in unexpected_keys])
+ error_message += f"\nMissing key(s): {str_unexpected_keys}."
+ raise RuntimeError(error_message)
+
+ return missing_keys, unexpected_keys, mismatched_keys
+
+
+def load_tf_weights(model, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None):
+ """
+ Detect missing and unexpected layers and load the TF weights from the shard file accordingly to their names and
+ shapes.
+
+ Args:
+ model (`keras.models.Model`):
+ The model to load the weights into.
+ resolved_archive_file (`str`):
+ The location of the H5 file.
+ ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
+ Whether or not to ignore weights with shapes that don't match between the checkpoint of the model.
+
+ Returns:
+ Three lists, one for the missing layers, another one for the unexpected layers, and a last one for the
+ mismatched layers.
+ """
+ if resolved_archive_file.endswith(".safetensors"):
+ load_function = load_tf_weights_from_safetensors
+ else:
+ load_function = load_tf_weights_from_h5
+
+ return load_function(
+ model, resolved_archive_file, ignore_mismatched_sizes=ignore_mismatched_sizes, _prefix=_prefix
+ )
+
+
+def load_tf_weights_from_h5(model, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None):
+ mismatched_layers = []
+
+ # Read the H5 file
+ with h5py.File(resolved_archive_file, "r") as sharded_checkpoint_file:
+ # Retrieve the name of each layer from the H5 file
+ saved_h5_model_layers_name = set(load_attributes_from_hdf5_group(sharded_checkpoint_file, "layer_names"))
+
+ # Find the missing layers from the high level list of layers
+ missing_layers = list({layer.name for layer in model.layers} - saved_h5_model_layers_name)
+
+ # Find the unexpected layers from the high level list of layers
+ unexpected_layers = list(saved_h5_model_layers_name - {layer.name for layer in model.layers})
+ saved_weight_names_set = set()
+ symbolic_weights_names = set()
+ weight_value_tuples = []
+
+ # Compute missing and unexpected sub layers
+ # Store the weights in list of tuples that looks like [(weight_object, value_of_weight),...]
+ for layer in model.layers:
+ # if layer_name from the H5 file belongs to the layers from the instantiated model
+ if layer.name in saved_h5_model_layers_name:
+ # Get the H5 layer object from its name
+ h5_layer_object = sharded_checkpoint_file[layer.name]
+ # Get all the weights as a list from the layer object
+ symbolic_weights = layer.trainable_weights + layer.non_trainable_weights
+ saved_weights = {}
+
+ # Create a dict from the H5 saved model that looks like {"weight_name": weight_value}
+ # And a set with only the names
+ for weight_name in load_attributes_from_hdf5_group(h5_layer_object, "weight_names"):
+ # TF names always start with the model name so we ignore it
+ name = "/".join(weight_name.split("/")[1:])
+
+ if _prefix is not None:
+ name = _prefix + "/" + name
+
+ saved_weights[name] = np.asarray(h5_layer_object[weight_name])
+
+ # Add the updated name to the final list for computing missing/unexpected values
+ saved_weight_names_set.add(name)
+
+ # Loop over each weights from the instantiated model and compare with the weights from the H5 file
+ for symbolic_weight in symbolic_weights:
+ # TF names always start with the model name so we ignore it
+ if _prefix is not None:
+ delimiter = len(_prefix.split("/"))
+ symbolic_weight_name = "/".join(
+ symbolic_weight.name.split("/")[:delimiter]
+ + symbolic_weight.name.split("/")[delimiter + 1 :]
+ )
+ else:
+ symbolic_weight_name = "/".join(symbolic_weight.name.split("/")[1:])
+
+ # here we check if the current weight is among the weights from the H5 file
+ # If yes, get the weight_value of the corresponding weight from the H5 file
+ # If not, make the value to None
+ saved_weight_value = saved_weights.get(symbolic_weight_name)
+
+ # Retrocompatibility patch: some embeddings are stored with the weights name (e.g. Bart's
+ # `model.shared/embeddings:0` are stored as `model.shared/weights:0`)
+ if saved_weight_value is None and symbolic_weight_name.endswith("embeddings:0"):
+ symbolic_weight_name = symbolic_weight_name[:-12] + "weight:0"
+ saved_weight_value = saved_weights.get(symbolic_weight_name)
+
+ # Add the updated name to the final list for computing missing/unexpected values
+ symbolic_weights_names.add(symbolic_weight_name)
+
+ # If the current weight is found
+ if saved_weight_value is not None:
+ # Check if the shape of the current weight and the one from the H5 file are different
+ if K.int_shape(symbolic_weight) != saved_weight_value.shape:
+ # If yes we reshape the weight from the H5 file accordingly to the current weight
+ # If the two shapes are not compatible we raise an issue
+ try:
+ array = np.reshape(saved_weight_value, K.int_shape(symbolic_weight))
+ except ValueError as e:
+ if ignore_mismatched_sizes:
+ mismatched_layers.append(
+ (symbolic_weight_name, saved_weight_value.shape, K.int_shape(symbolic_weight))
+ )
+ continue
+ else:
+ raise e
+ else:
+ array = saved_weight_value
+
+ # We create the tuple that will be loaded and add it to the final list
+ weight_value_tuples.append((symbolic_weight, array))
+
+ # Load all the weights
+ K.batch_set_value(weight_value_tuples)
+
+ # Compute the missing and unexpected layers
+ missing_layers.extend(list(symbolic_weights_names - saved_weight_names_set))
+ unexpected_layers.extend(list(saved_weight_names_set - symbolic_weights_names))
+
+ return missing_layers, unexpected_layers, mismatched_layers
+
+
+def load_tf_weights_from_safetensors(model, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None):
+ # Read the safetensors file
+ with safe_open(resolved_archive_file, framework="tf") as safetensors_archive:
+ mismatched_layers = []
+ weight_names = [strip_model_name_and_prefix(w.name, _prefix=_prefix) for w in model.weights]
+ loaded_weight_names = list(safetensors_archive.keys())
+ # Find the missing layers from the high level list of layers
+ missing_layers = list(set(weight_names) - set(loaded_weight_names))
+ # Find the unexpected layers from the high level list of layers
+ unexpected_layers = list(set(loaded_weight_names) - set(weight_names))
+
+ for weight in model.weights:
+ weight_name = strip_model_name_and_prefix(weight.name, _prefix=_prefix)
+ if weight_name in loaded_weight_names:
+ weight_value = safetensors_archive.get_tensor(weight_name)
+ # Check if the shape of the current weight and the one from the H5 file are different
+ if K.int_shape(weight) != weight_value.shape:
+ # If yes we reshape the weight from the H5 file accordingly to the current weight
+ # If the two shapes are not compatible we raise an issue
+ try:
+ weight_value = tf.reshape(weight_value, K.int_shape(weight))
+ except (ValueError, tf.errors.InvalidArgumentError) as e:
+ if ignore_mismatched_sizes:
+ mismatched_layers.append((weight_name, weight_value.shape, K.int_shape(weight)))
+ continue
+ else:
+ raise e
+
+ K.set_value(weight, weight_value) # weight.assign() might break if weight is a DTensor
+ return missing_layers, unexpected_layers, mismatched_layers
+
+
+def init_copy_embeddings(old_embeddings, new_num_tokens):
+ r"""
+ This function aims to reduce the embeddings in case new_num_tokens < old_num_tokens or to pad with -1 in case
+ new_num_tokens > old_num_tokens. A mask is also computed in order to know which weight in the embeddings should be
+ kept or not. Example:
+
+ - if new_num_tokens=5 and old_num_tokens=4 and old_embeddings=[w1,w2,w3,w4]
+
+ - mask=[True,True,True,True,False] and current_weights=[w1,w2,w3,w4,-1]
+ - if new_num_tokens=4 and old_num_tokens=5 and old_embeddings=[w1,w2,w3,w4,w5]
+
+ - mask=[True,True,True,True] and current_weights=[w1,w2,w3,w4]
+ """
+ old_num_tokens, old_embedding_dim = shape_list(old_embeddings)
+ size_diff = new_num_tokens - old_num_tokens
+
+ # initialize new embeddings
+ # Copy token embeddings from the previous ones
+ if tf.math.greater(size_diff, 0):
+ # if the new size is greater than the old one, we extend the current embeddings with a padding until getting new size
+ # and we create a mask to properly identify the padded values and be replaced by the values of the newly created
+ # embeddings
+ current_weights = tf.pad(
+ old_embeddings.value(), tf.convert_to_tensor([[0, size_diff], [0, 0]]), constant_values=-1
+ )
+ num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
+ mask = tf.fill(tf.convert_to_tensor([num_tokens_to_copy, 1]), True)
+ mask = tf.pad(mask, tf.convert_to_tensor([[0, size_diff], [0, 0]]), constant_values=False)
+ else:
+ # if the new size if lower than the old one, we take the current embeddings until the new size
+ current_weights = tf.slice(
+ old_embeddings.value(),
+ tf.convert_to_tensor([0, 0]),
+ tf.convert_to_tensor([new_num_tokens, old_embedding_dim]),
+ )
+ mask = tf.fill(tf.convert_to_tensor([new_num_tokens, 1]), True)
+
+ return mask, current_weights
+
+
+class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushToHubMixin):
+ r"""
+ Base class for all TF models.
+
+ [`TFPreTrainedModel`] takes care of storing the configuration of the models and handles methods for loading,
+ downloading and saving models as well as a few methods common to all models to:
+
+ - resize the input embeddings,
+ - prune heads in the self-attention heads.
+
+ Class attributes (overridden by derived classes):
+
+ - **config_class** ([`PretrainedConfig`]) -- A subclass of [`PretrainedConfig`] to use as configuration class
+ for this model architecture.
+ - **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived
+ classes of the same architecture adding modules on top of the base model.
+ - **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP
+ models, `pixel_values` for vision models and `input_values` for speech models).
+ """
+
+ config_class = None
+ base_model_prefix = ""
+ main_input_name = "input_ids"
+ _auto_class = None
+ _using_dummy_loss = None
+ _label_to_output_map = None
+
+ # a list of re pattern of tensor names to ignore from the model when loading the model weights
+ # (and avoid unnecessary warnings).
+ _keys_to_ignore_on_load_missing = None
+ # a list of re pattern of tensor names to ignore from the weights when loading the model weights
+ # (and avoid unnecessary warnings).
+ _keys_to_ignore_on_load_unexpected = None
+ _requires_load_weight_prefix = False
+
+ @property
+ def dummy_inputs(self) -> dict[str, tf.Tensor]:
+ """
+ Dummy inputs to build the network.
+
+ Returns:
+ `dict[str, tf.Tensor]`: The dummy inputs.
+ """
+ dummies = {}
+ for key, spec in self.input_signature.items():
+ # 2 is the most correct arbitrary size. I will not be taking questions
+ dummy_shape = [dim if dim is not None else 2 for dim in spec.shape]
+ if spec.shape[0] is None:
+ # But let's make the batch size 1 to save memory anyway
+ dummy_shape[0] = 1
+ dummies[key] = tf.ones(shape=dummy_shape, dtype=spec.dtype)
+ if key == "token_type_ids":
+ # Some models have token_type_ids but with a vocab_size of 1
+ dummies[key] = tf.zeros_like(dummies[key])
+ if self.config.add_cross_attention and "encoder_hidden_states" in inspect.signature(self.call).parameters:
+ if "encoder_hidden_states" not in dummies:
+ if self.main_input_name == "input_ids":
+ dummies["encoder_hidden_states"] = tf.ones(
+ shape=(1, 2, self.config.hidden_size), dtype=tf.float32, name="encoder_hidden_states"
+ )
+ else:
+ raise NotImplementedError(
+ "Model has cross-attention but we couldn't infer the shape for the encoder hidden states. Please manually override dummy_inputs!"
+ )
+ return dummies
+
+ def build_in_name_scope(self):
+ with tf.name_scope(self.name):
+ self.build(input_shape=None)
+
+ @property
+ def framework(self) -> str:
+ """
+ :str: Identifies that this is a TensorFlow model.
+ """
+ return "tf"
+
+ def build(self, input_shape=None):
+ pass # This is just here to make sure we don't call the superclass build()
+
+ def __init__(self, config, *inputs, **kwargs):
+ super().__init__(*inputs, **kwargs)
+ if not isinstance(config, PretrainedConfig):
+ raise TypeError(
+ f"Parameter config in `{self.__class__.__name__}(config)` should be an instance of class "
+ "`PretrainedConfig`. To create a model from a pretrained model use "
+ f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
+ )
+ # Save config and origin of the pretrained weights if given in model
+ self.config = config
+ self.name_or_path = config.name_or_path
+ self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
+ self._set_save_spec(self.input_signature)
+ logger.warning_once(
+ "TensorFlow and JAX classes are deprecated and will be removed in Transformers v5. We "
+ "recommend migrating to PyTorch classes or pinning your version of Transformers."
+ )
+
+ def get_config(self):
+ return self.config.to_dict()
+
+ @functools.wraps(keras.Model.fit)
+ def fit(self, *args, **kwargs):
+ args, kwargs = convert_batch_encoding(*args, **kwargs)
+ return super().fit(*args, **kwargs)
+
+ @functools.wraps(keras.Model.train_on_batch)
+ def train_on_batch(self, *args, **kwargs):
+ args, kwargs = convert_batch_encoding(*args, **kwargs)
+ return super().train_on_batch(*args, **kwargs)
+
+ @functools.wraps(keras.Model.test_on_batch)
+ def test_on_batch(self, *args, **kwargs):
+ args, kwargs = convert_batch_encoding(*args, **kwargs)
+ return super().test_on_batch(*args, **kwargs)
+
+ @functools.wraps(keras.Model.predict_on_batch)
+ def predict_on_batch(self, *args, **kwargs):
+ args, kwargs = convert_batch_encoding(*args, **kwargs)
+ return super().predict_on_batch(*args, **kwargs)
+
+ @functools.wraps(keras.Model.predict)
+ def predict(self, *args, **kwargs):
+ args, kwargs = convert_batch_encoding(*args, **kwargs)
+ return super().predict(*args, **kwargs)
+
+ @functools.wraps(keras.Model.evaluate)
+ def evaluate(self, *args, **kwargs):
+ args, kwargs = convert_batch_encoding(*args, **kwargs)
+ return super().evaluate(*args, **kwargs)
+
+ @classmethod
+ def from_config(cls, config, **kwargs):
+ if isinstance(config, PretrainedConfig):
+ return cls._from_config(config, **kwargs)
+ return cls._from_config(cls.config_class.from_dict(config, **kwargs))
+
+ @classmethod
+ def _from_config(cls, config, **kwargs):
+ """
+ All context managers that the model should be initialized under go here.
+ """
+ return cls(config, **kwargs)
+
+ def get_head_mask(self, head_mask: tf.Tensor | None, num_hidden_layers: int) -> tf.Tensor:
+ """
+ Prepare the head mask if needed.
+
+ Args:
+ head_mask (`tf.Tensor` with shape `[num_heads]` or `[num_hidden_layers x num_heads]`, *optional*):
+ The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard).
+ num_hidden_layers (`int`):
+ The number of hidden layers in the model.
+
+ Returns:
+ `tf.Tensor` with shape `[num_hidden_layers x batch x num_heads x seq_length x seq_length]` or list with
+ `[None]` for each layer.
+ """
+ if head_mask is not None:
+ head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
+ else:
+ head_mask = [None] * num_hidden_layers
+
+ return head_mask
+
+ def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers):
+ """-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]"""
+ if head_mask.shape.rank == 1:
+ head_mask = head_mask[None, None, :, None, None]
+ head_mask = tf.repeat(head_mask, repeats=num_hidden_layers, axis=0)
+ elif head_mask.shape.rank == 2:
+ head_mask = head_mask[:, None, :, None, None]
+ assert head_mask.shape.rank == 5, f"head_mask.dim != 5, instead {head_mask.dim()}"
+ head_mask = tf.cast(head_mask, tf.float32) # switch to float if need + fp16 compatibility
+ return head_mask
+
+ @tf.function
+ def serving(self, inputs):
+ """
+ Args:
+ Method used for serving the model. Does not have a specific signature, but will be specialized as concrete
+ functions when saving with `save_pretrained`.
+ inputs (`dict[str, tf.Tensor]`):
+ The input of the saved model as a dictionary of tensors.
+ """
+ output = self.call(inputs)
+
+ return self.serving_output(output)
+
+ @property
+ def input_signature(self) -> dict[str, tf.TensorSpec]:
+ """
+ This property should return a dict mapping input names to tf.TensorSpec objects, representing the expected
+ shape and dtype for model inputs. It is used for both serving and for generating dummy inputs.
+ """
+ model_inputs = list(inspect.signature(self.call).parameters)
+ sig = {}
+ if "input_ids" in model_inputs:
+ if self.__class__.__name__.endswith("ForMultipleChoice"):
+ text_dims = 3
+ else:
+ text_dims = 2
+ for input_name in (
+ "input_ids",
+ "attention_mask",
+ "token_type_ids",
+ "decoder_input_ids",
+ "decoder_attention_mask",
+ ):
+ if input_name in model_inputs:
+ sig[input_name] = tf.TensorSpec([None] * text_dims, tf.int32, name=input_name)
+ if "pixel_values" in model_inputs:
+ pixel_values_shape = [None, None, None, None]
+ if hasattr(self.config, "vision_config"):
+ vision_config = self.config.vision_config
+ else:
+ vision_config = self.config
+ if hasattr(vision_config, "num_channels"):
+ pixel_values_shape[1] = vision_config.num_channels
+ else:
+ raise NotImplementedError(
+ "Could not infer number of channels from config, please override input_signature to specify input shapes."
+ )
+ if hasattr(vision_config, "image_size"):
+ pixel_values_shape[2] = pixel_values_shape[3] = vision_config.image_size
+ elif hasattr(vision_config, "input_size"):
+ pixel_values_shape[2] = pixel_values_shape[3] = vision_config.input_size
+ else:
+ raise NotImplementedError(
+ "Could not infer input image shape from config, please override input_signature to specify input shapes."
+ )
+ sig["pixel_values"] = tf.TensorSpec(pixel_values_shape, tf.float32, name="pixel_values")
+ if "input_features" in model_inputs:
+ raise NotImplementedError("Audio models need a manually defined input_signature")
+ return sig
+
+ def serving_output(self, output):
+ """
+ Prepare the output of the saved model. Can be overridden if specific serving modifications are required.
+ """
+ if not isinstance(output, ModelOutput):
+ return output
+ for key in output:
+ if key.endswith("hidden_states") and not getattr(self.config, "output_hidden_states", False):
+ output[key] = None
+ elif key.endswith("attentions") and not getattr(self.config, "output_attentions", False):
+ output[key] = None
+ elif key == "past_key_values" and not getattr(self.config, "use_cache", False):
+ output[key] = None
+ elif key == "cross_attentions" and not (
+ getattr(self.config, "output_attentions", False) and getattr(self.config, "add_cross_attention", False)
+ ):
+ output[key] = None
+ if isinstance(output[key], (tuple, list)):
+ try:
+ output[key] = tf.convert_to_tensor(output[key])
+ except (ValueError, tf.errors.InvalidArgumentError):
+ pass # Layers may not have the same dimensions
+ return output
+
+ @classmethod
+ def can_generate(cls) -> bool:
+ """
+ Returns whether this model can generate sequences with `.generate()`.
+
+ Returns:
+ `bool`: Whether this model can generate sequences with `.generate()`.
+ """
+ # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation.
+ # Alternatively, the model can also have a custom `generate` function.
+ if "GenerationMixin" in str(cls.prepare_inputs_for_generation) and "GenerationMixin" in str(cls.generate):
+ return False
+ return True
+
+ def get_input_embeddings(self) -> keras.layers.Layer:
+ """
+ Returns the model's input embeddings layer.
+
+ Returns:
+ `tf.Variable`: The embeddings layer mapping vocabulary to hidden states.
+ """
+ main_layer = getattr(self, self.base_model_prefix, self)
+
+ if main_layer is not self:
+ return main_layer.get_input_embeddings()
+ else:
+ raise NotImplementedError
+
+ def _save_checkpoint(self, checkpoint_dir, epoch):
+ if not os.path.isdir(checkpoint_dir):
+ os.mkdir(checkpoint_dir)
+ # We avoid tf.train.checkpoint or saving weights in TF format, even though that includes optimizer
+ # state for us, because it requires special handling for objects like custom losses, which we use
+ # internally and which users are likely to use too
+ weights_path = os.path.join(checkpoint_dir, "weights.h5")
+ self.save_weights(weights_path)
+ extra_data = {"epoch": epoch, "optimizer_state": self.optimizer.get_weights()}
+ extra_data_path = os.path.join(checkpoint_dir, "extra_data.pickle")
+ with open(extra_data_path, "wb") as f:
+ pickle.dump(extra_data, f)
+
+ def prepare_tf_dataset(
+ self,
+ dataset: datasets.Dataset, # noqa:F821
+ batch_size: int = 8,
+ shuffle: bool = True,
+ tokenizer: PreTrainedTokenizerBase | None = None,
+ collate_fn: Callable | None = None,
+ collate_fn_args: dict[str, Any] | None = None,
+ drop_remainder: bool | None = None,
+ prefetch: bool = True,
+ ):
+ """
+ Wraps a HuggingFace [`~datasets.Dataset`] as a `tf.data.Dataset` with collation and batching. This method is
+ designed to create a "ready-to-use" dataset that can be passed directly to Keras methods like `fit()` without
+ further modification. The method will drop columns from the dataset if they don't match input names for the
+ model. If you want to specify the column names to return rather than using the names that match this model, we
+ recommend using `Dataset.to_tf_dataset()` instead.
+
+ Args:
+ dataset (`Any`):
+ A [~`datasets.Dataset`] to be wrapped as a `tf.data.Dataset`.
+ batch_size (`int`, *optional*, defaults to 8):
+ The size of batches to return.
+ shuffle (`bool`, defaults to `True`):
+ Whether to return samples from the dataset in random order. Usually `True` for training datasets and
+ `False` for validation/test datasets.
+ tokenizer ([`PreTrainedTokenizerBase`], *optional*):
+ A `PreTrainedTokenizer` that will be used to pad samples to create batches. Has no effect if a specific
+ `collate_fn` is passed instead.
+ collate_fn (`Callable`, *optional*):
+ A function that collates samples from the dataset into a single batch. Defaults to
+ `DefaultDataCollator` if no `tokenizer` is supplied or `DataCollatorWithPadding` if a `tokenizer` is
+ passed.
+ collate_fn_args (`dict[str, Any]`, *optional*):
+ A dict of arguments to pass to the `collate_fn` alongside the list of samples.
+ drop_remainder (`bool`, *optional*):
+ Whether to drop the final batch, if the batch_size does not evenly divide the dataset length. Defaults
+ to the same setting as `shuffle`.
+ prefetch (`bool`, defaults to `True`):
+ Whether to add prefetching to the end of the `tf.data` pipeline. This is almost always beneficial for
+ performance, but can be disabled in edge cases.
+
+
+ Returns:
+ `Dataset`: A `tf.data.Dataset` which is ready to pass to the Keras API.
+ """
+ requires_backends(self, ["datasets"])
+ import datasets
+
+ if collate_fn is None:
+ if tokenizer is None:
+ collate_fn = DefaultDataCollator(return_tensors="np")
+ else:
+ collate_fn = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="np")
+ if collate_fn_args is None:
+ collate_fn_args = {}
+
+ if not isinstance(dataset, datasets.Dataset):
+ raise TypeError("Dataset argument should be a datasets.Dataset!")
+ model_inputs = list(inspect.signature(self.call).parameters)
+ model_labels = find_labels(self.__class__)
+ if "cols_to_retain" in list(inspect.signature(dataset._get_output_signature).parameters.keys()):
+ output_signature, _ = dataset._get_output_signature(
+ dataset,
+ batch_size=None,
+ collate_fn=collate_fn,
+ collate_fn_args=collate_fn_args,
+ cols_to_retain=model_inputs,
+ )
+ else:
+ # TODO Matt: This is a workaround for older versions of datasets that are missing the `cols_to_retain`
+ # argument. We should remove this once the minimum supported version of datasets is > 2.3.2
+ unwanted_columns = [
+ feature
+ for feature in dataset.features
+ if feature not in model_inputs and feature not in ("label_ids", "label")
+ ]
+ dataset = dataset.remove_columns(unwanted_columns)
+ output_signature, _ = dataset._get_output_signature(
+ dataset, batch_size=None, collate_fn=collate_fn, collate_fn_args=collate_fn_args
+ )
+ output_columns = list(output_signature.keys())
+ feature_cols = [col for col in output_columns if col in model_inputs and col not in model_labels]
+ label_cols = [col for col in output_columns if col in model_labels]
+
+ # Backwards compatibility for older versions of datasets. Previously, if `columns` or `label_cols`
+ # were a single element list, the returned element spec would be a single element. Now, passing [feature]
+ # will return a dict structure {"feature": feature}, and passing a single string will return a single element.
+ feature_cols = feature_cols[0] if len(feature_cols) == 1 else feature_cols
+ label_cols = label_cols[0] if len(label_cols) == 1 else label_cols
+
+ if drop_remainder is None:
+ drop_remainder = shuffle
+ tf_dataset = dataset.to_tf_dataset(
+ columns=feature_cols,
+ label_cols=label_cols,
+ batch_size=batch_size,
+ shuffle=shuffle,
+ drop_remainder=drop_remainder,
+ collate_fn=collate_fn,
+ collate_fn_args=collate_fn_args,
+ prefetch=prefetch,
+ )
+ return tf_dataset
+
+ def compile(
+ self,
+ optimizer="rmsprop",
+ loss="auto_with_warning",
+ metrics=None,
+ loss_weights=None,
+ weighted_metrics=None,
+ run_eagerly=None,
+ steps_per_execution=None,
+ **kwargs,
+ ):
+ """
+ This is a thin wrapper that sets the model's loss output head as the loss if the user does not specify a loss
+ function themselves.
+ """
+ if loss in ("auto_with_warning", "passthrough"): # "passthrough" for workflow backward compatibility
+ logger.info(
+ "No loss specified in compile() - the model's internal loss computation will be used as the "
+ "loss. Don't panic - this is a common way to train TensorFlow models in Transformers! "
+ "To disable this behaviour please pass a loss argument, or explicitly pass "
+ "`loss=None` if you do not want your model to compute a loss. You can also specify `loss='auto'` to "
+ "get the internal loss without printing this info string."
+ )
+ loss = "auto"
+ if loss == "auto":
+ loss = dummy_loss
+ self._using_dummy_loss = True
+ else:
+ self._using_dummy_loss = False
+ parent_args = list(inspect.signature(keras.Model.compile).parameters.keys())
+ # This argument got renamed, we need to support both versions
+ if "steps_per_execution" in parent_args:
+ super().compile(
+ optimizer=optimizer,
+ loss=loss,
+ metrics=metrics,
+ loss_weights=loss_weights,
+ weighted_metrics=weighted_metrics,
+ run_eagerly=run_eagerly,
+ steps_per_execution=steps_per_execution,
+ **kwargs,
+ )
+ else:
+ super().compile(
+ optimizer=optimizer,
+ loss=loss,
+ metrics=metrics,
+ loss_weights=loss_weights,
+ weighted_metrics=weighted_metrics,
+ run_eagerly=run_eagerly,
+ experimental_steps_per_execution=steps_per_execution,
+ **kwargs,
+ )
+
+ def compute_loss(self, *args, **kwargs):
+ if hasattr(keras.Model, "compute_loss"):
+ # This will be true in TF 2.8 or greater
+ return super().compute_loss(*args, **kwargs)
+ else:
+ warnings.warn(
+ "The old compute_loss method is deprecated as it conflicts with the Keras compute_loss "
+ "method added in TF 2.8. If you want the original HF compute_loss, please call "
+ "hf_compute_loss() instead. From TF versions >= 2.8, or Transformers versions >= 5, "
+ "calling compute_loss() will get the Keras method instead.",
+ FutureWarning,
+ )
+ return self.hf_compute_loss(*args, **kwargs)
+
+ def get_label_to_output_name_mapping(self):
+ arg_names = list(inspect.signature(self.call).parameters)
+ if self._label_to_output_map is not None:
+ return self._label_to_output_map
+ elif "start_positions" in arg_names:
+ return {"start_positions": "start_logits", "end_positions": "end_logits"}
+ elif "sentence_order_label" in arg_names:
+ return {"labels": "prediction_logits", "sentence_order_label": "sop_logits"}
+ elif "next_sentence_label" in arg_names:
+ return {"labels": "prediction_logits", "next_sentence_label": "seq_relationship_logits"}
+ elif "mc_labels" in arg_names:
+ return {"labels": "logits", "mc_labels": "mc_logits"}
+ else:
+ return {}
+
+ def train_step(self, data):
+ """
+ A modification of Keras's default `train_step` that correctly handles matching outputs to labels for our models
+ and supports directly training on the loss output head. In addition, it ensures input keys are copied to the
+ labels where appropriate. It will also copy label keys into the input dict when using the dummy loss, to ensure
+ that they are available to the model during the forward pass.
+ """
+
+ # We hardcode the most common renamings; models with weirder names can set `self._label_to_output_map`
+ arg_names = list(inspect.signature(self.call).parameters)
+ label_kwargs = find_labels(self.__class__)
+ label_to_output = self.get_label_to_output_name_mapping()
+ output_to_label = {val: key for key, val in label_to_output.items()}
+ if not self._using_dummy_loss and parse(tf.__version__) < parse("2.11.0"):
+ # Newer TF train steps leave this out
+ data = expand_1d(data)
+ x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data)
+ # If the inputs are mutable dictionaries, make a shallow copy of them because we will modify
+ # them during input/label pre-processing. This avoids surprising the user by wrecking their data.
+ # In addition, modifying mutable Python inputs makes XLA compilation impossible.
+ if isinstance(x, dict):
+ x = x.copy()
+ if isinstance(y, dict):
+ y = y.copy()
+
+ # When using a dummy loss, we ensure that separate labels are copied to the correct model arguments,
+ # if those keys are not already present in the input dict
+ if self._using_dummy_loss and y is not None:
+ # If y is a tensor and the model only has one label-like input, map y to that input
+ if len(label_kwargs) == 1 and isinstance(y, tf.Tensor):
+ if isinstance(x, tf.Tensor):
+ x = {arg_names[0]: x}
+ label_kwarg = next(iter(label_kwargs))
+ if label_kwarg not in x:
+ x[label_kwarg] = y
+ # Otherwise, copy keys from y to x as long as they weren't already present in x
+ elif isinstance(y, dict):
+ if isinstance(x, tf.Tensor):
+ x = {arg_names[0]: x}
+ for key, val in y.items():
+ if key in arg_names and key not in x:
+ x[key] = val
+ elif output_to_label.get(key) in arg_names and key not in x:
+ x[output_to_label[key]] = val
+ if y is None:
+ y = {key: val for key, val in x.items() if key in label_kwargs}
+ if not y and not self._using_dummy_loss:
+ raise ValueError("Could not find label column(s) in input dict and no separate labels were provided!")
+
+ if isinstance(y, dict):
+ # Rename labels at this point to match output heads
+ y = {label_to_output.get(key, key): val for key, val in y.items()}
+
+ # Run forward pass.
+ with tf.GradientTape() as tape:
+ if self._using_dummy_loss and "return_loss" in arg_names:
+ y_pred = self(x, training=True, return_loss=True)
+ else:
+ y_pred = self(x, training=True)
+ if self._using_dummy_loss:
+ loss = self.compiled_loss(y_pred.loss, y_pred.loss, sample_weight, regularization_losses=self.losses)
+ else:
+ loss = None
+
+ # This next block matches outputs to label keys. Tensorflow's standard method for doing this
+ # can get very confused if any of the keys contain nested values (e.g. lists/tuples of Tensors)
+ if isinstance(y, dict) and len(y) == 1:
+ if list(y.keys())[0] in y_pred:
+ y_pred = y_pred[list(y.keys())[0]]
+ elif list(y_pred.keys())[0] == "loss":
+ y_pred = y_pred[1]
+ else:
+ y_pred = y_pred[0]
+ _, y = y.popitem()
+ elif isinstance(y, dict):
+ # If the labels are a dict, match keys from the output by name
+ y_pred = {key: val for key, val in y_pred.items() if key in y}
+ elif isinstance(y, (tuple, list)):
+ # If the labels are a tuple/list, match keys to the output by order, skipping the loss.
+ if list(y_pred.keys())[0] == "loss":
+ y_pred = y_pred.to_tuple()[1:]
+ else:
+ y_pred = y_pred.to_tuple()
+ y_pred = y_pred[: len(y)] # Remove unused fields in case those cause problems
+ else:
+ # If the labels are a single tensor, match them to the first non-loss tensor in the output
+ if list(y_pred.keys())[0] == "loss":
+ y_pred = y_pred[1]
+ else:
+ y_pred = y_pred[0]
+
+ if loss is None:
+ loss = self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
+
+ # Run backwards pass.
+ self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
+
+ self.compiled_metrics.update_state(y, y_pred, sample_weight)
+ # Collect metrics to return
+ return_metrics = {}
+ for metric in self.metrics:
+ result = metric.result()
+ if isinstance(result, dict):
+ return_metrics.update(result)
+ else:
+ return_metrics[metric.name] = result
+ return return_metrics
+
+ def test_step(self, data):
+ """
+ A modification of Keras's default `train_step` that correctly handles matching outputs to labels for our models
+ and supports directly training on the loss output head. In addition, it ensures input keys are copied to the
+ labels where appropriate. It will also copy label keys into the input dict when using the dummy loss, to ensure
+ that they are available to the model during the forward pass.
+ """
+ # We hardcode the most common renamings; models with weirder names can set `self._label_to_output_map`
+ arg_names = list(inspect.signature(self.call).parameters)
+ label_kwargs = find_labels(self.__class__)
+ label_to_output = self.get_label_to_output_name_mapping()
+ output_to_label = {val: key for key, val in label_to_output.items()}
+ if not self._using_dummy_loss and parse(tf.__version__) < parse("2.11.0"):
+ # Newer versions leave this out
+ data = expand_1d(data)
+ x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data)
+ # If the inputs are mutable dictionaries, make a shallow copy of them because we will modify
+ # them during input/label pre-processing. This avoids surprising the user by wrecking their data.
+ # In addition, modifying mutable Python inputs makes XLA compilation impossible.
+ if isinstance(x, dict):
+ x = x.copy()
+ if isinstance(y, dict):
+ y = y.copy()
+
+ # When using a dummy loss, we ensure that separate labels are copied to the correct model arguments,
+ # if those keys are not already present in the input dict
+ if self._using_dummy_loss and y is not None:
+ arg_names = list(inspect.signature(self.call).parameters)
+ # If y is a tensor and the model only has one label-like input, map y to that input
+ if len(label_kwargs) == 1 and isinstance(y, tf.Tensor):
+ if isinstance(x, tf.Tensor):
+ x = {arg_names[0]: x}
+ label_kwarg = next(iter(label_kwargs))
+ if label_kwarg not in x:
+ x[label_kwarg] = y
+ # Otherwise, copy keys from y to x as long as they weren't already present in x
+ elif isinstance(y, dict):
+ if isinstance(x, tf.Tensor):
+ x = {arg_names[0]: x}
+ for key, val in y.items():
+ if key in arg_names and key not in x:
+ x[key] = val
+ elif output_to_label.get(key) in arg_names and key not in x:
+ x[output_to_label[key]] = val
+ if y is None:
+ y = {key: val for key, val in x.items() if key in label_kwargs}
+ if not y and not self._using_dummy_loss:
+ raise ValueError("Could not find label column(s) in input dict and no separate labels were provided!")
+
+ if isinstance(y, dict):
+ # Rename labels at this point to match output heads
+ y = {label_to_output.get(key, key): val for key, val in y.items()}
+
+ # Run forward pass.
+ if self._using_dummy_loss and "return_loss" in arg_names:
+ y_pred = self(x, return_loss=True, training=False)
+ else:
+ y_pred = self(x, training=False)
+ if self._using_dummy_loss:
+ loss = self.compiled_loss(y_pred.loss, y_pred.loss, sample_weight, regularization_losses=self.losses)
+ else:
+ loss = None
+
+ # This next block matches outputs to label keys. Tensorflow's standard method for doing this
+ # can get very confused if any of the keys contain nested values (e.g. lists/tuples of Tensors)
+ if isinstance(y, dict) and len(y) == 1:
+ if list(y.keys())[0] in y_pred:
+ y_pred = y_pred[list(y.keys())[0]]
+ elif list(y_pred.keys())[0] == "loss":
+ y_pred = y_pred[1]
+ else:
+ y_pred = y_pred[0]
+ _, y = y.popitem()
+ elif isinstance(y, dict):
+ # If the labels are a dict, match keys from the output by name
+ y_pred = {key: val for key, val in y_pred.items() if key in y}
+ elif isinstance(y, (tuple, list)):
+ # If the labels are a tuple/list, match keys to the output by order, skipping the loss.
+ if list(y_pred.keys())[0] == "loss":
+ y_pred = y_pred.to_tuple()[1:]
+ else:
+ y_pred = y_pred.to_tuple()
+ y_pred = y_pred[: len(y)] # Remove unused fields in case those cause problems
+ else:
+ # If the labels are a single tensor, match them to the first non-loss tensor in the output
+ if list(y_pred.keys())[0] == "loss":
+ y_pred = y_pred[1]
+ else:
+ y_pred = y_pred[0]
+
+ if loss is None:
+ loss = self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
+
+ self.compiled_metrics.update_state(y, y_pred, sample_weight)
+ # Collect metrics to return
+ return_metrics = {}
+ for metric in self.metrics:
+ result = metric.result()
+ if isinstance(result, dict):
+ return_metrics.update(result)
+ else:
+ return_metrics[metric.name] = result
+ return return_metrics
+
+ def create_model_card(
+ self,
+ output_dir,
+ model_name: str,
+ language: str | None = None,
+ license: str | None = None,
+ tags: str | None = None,
+ finetuned_from: str | None = None,
+ tasks: str | None = None,
+ dataset_tags: str | list[str] | None = None,
+ dataset: str | list[str] | None = None,
+ dataset_args: str | list[str] | None = None,
+ ):
+ """
+ Creates a draft of a model card using the information available to the `Trainer`.
+
+ Args:
+ output_dir (`str` or `os.PathLike`):
+ The folder in which to create the model card.
+ model_name (`str`, *optional*):
+ The name of the model.
+ language (`str`, *optional*):
+ The language of the model (if applicable)
+ license (`str`, *optional*):
+ The license of the model. Will default to the license of the pretrained model used, if the original
+ model given to the `Trainer` comes from a repo on the Hub.
+ tags (`str` or `list[str]`, *optional*):
+ Some tags to be included in the metadata of the model card.
+ finetuned_from (`str`, *optional*):
+ The name of the model used to fine-tune this one (if applicable). Will default to the name of the repo
+ of the original model given to the `Trainer` (if it comes from the Hub).
+ tasks (`str` or `list[str]`, *optional*):
+ One or several task identifiers, to be included in the metadata of the model card.
+ dataset_tags (`str` or `list[str]`, *optional*):
+ One or several dataset tags, to be included in the metadata of the model card.
+ dataset (`str` or `list[str]`, *optional*):
+ One or several dataset identifiers, to be included in the metadata of the model card.
+ dataset_args (`str` or `list[str]`, *optional*):
+ One or several dataset arguments, to be included in the metadata of the model card.
+ """
+ # Avoids a circular import by doing this when necessary.
+ from .modelcard import TrainingSummary # tests_ignore
+
+ training_summary = TrainingSummary.from_keras(
+ self,
+ keras_history=self.history,
+ language=language,
+ license=license,
+ tags=tags,
+ model_name=model_name,
+ finetuned_from=finetuned_from,
+ tasks=tasks,
+ dataset_tags=dataset_tags,
+ dataset=dataset,
+ dataset_args=dataset_args,
+ )
+ model_card = training_summary.to_model_card()
+ with open(os.path.join(output_dir, "README.md"), "w") as f:
+ f.write(model_card)
+
+ def set_input_embeddings(self, value):
+ """
+ Set model's input embeddings
+
+ Args:
+ value (`tf.Variable`):
+ The new weights mapping hidden states to vocabulary.
+ """
+ main_layer = getattr(self, self.base_model_prefix)
+
+ if main_layer is None:
+ raise NotImplementedError("The model does not implements the base_model_prefix attribute.")
+
+ try:
+ main_layer.set_input_embeddings(value)
+ except AttributeError:
+ logger.info("Building the model")
+ self.build_in_name_scope()
+ main_layer.set_input_embeddings(value)
+
+ def get_output_embeddings(self) -> None | keras.layers.Layer:
+ """
+ Returns the model's output embeddings
+
+ Returns:
+ `tf.Variable`: The new weights mapping vocabulary to hidden states.
+ """
+ if self.get_lm_head() is not None:
+ lm_head = self.get_lm_head()
+
+ try:
+ return lm_head.get_output_embeddings()
+ except AttributeError:
+ logger.info("Building the model")
+ self.build_in_name_scope()
+
+ return lm_head().get_output_embeddings()
+
+ return None # Overwrite for models with output embeddings
+
+ def set_output_embeddings(self, value):
+ """
+ Set model's output embeddings
+
+ Args:
+ value (`tf.Variable`):
+ The new weights mapping hidden states to vocabulary.
+ """
+ if self.get_lm_head() is not None:
+ lm_head = self.get_lm_head()
+ try:
+ lm_head.set_output_embeddings(value)
+ except AttributeError:
+ logger.info("Building the model")
+ self.build_in_name_scope()
+ lm_head.set_output_embeddings(value)
+
+ def get_output_layer_with_bias(self) -> None | keras.layers.Layer:
+ """
+ Get the layer that handles a bias attribute in case the model has an LM head with weights tied to the
+ embeddings
+
+ Return:
+ `keras.layers.Layer`: The layer that handles the bias, None if not an LM model.
+ """
+ warnings.warn(
+ "The method get_output_layer_with_bias is deprecated. Please use `get_lm_head` instead.", FutureWarning
+ )
+ return self.get_lm_head()
+
+ def get_prefix_bias_name(self) -> None | str:
+ """
+ Get the concatenated _prefix name of the bias from the model name to the parent layer
+
+ Return:
+ `str`: The _prefix name of the bias.
+ """
+ warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
+ return None
+
+ def get_bias(self) -> None | dict[str, tf.Variable]:
+ """
+ Dict of bias attached to an LM head. The key represents the name of the bias attribute.
+
+ Return:
+ `tf.Variable`: The weights representing the bias, None if not an LM model.
+ """
+ if self.get_lm_head() is not None:
+ lm_head = self.get_lm_head()
+ try:
+ return lm_head.get_bias()
+ except AttributeError:
+ self.build_in_name_scope()
+
+ return lm_head.get_bias()
+ return None
+
+ def set_bias(self, value):
+ """
+ Set all the bias in the LM head.
+
+ Args:
+ value (`dict[tf.Variable]`):
+ All the new bias attached to an LM head.
+ """
+ if self.get_lm_head() is not None:
+ lm_head = self.get_lm_head()
+ try:
+ lm_head.set_bias(value)
+ except AttributeError:
+ self.build_in_name_scope()
+ lm_head.set_bias(value)
+
+ def get_lm_head(self) -> keras.layers.Layer:
+ """
+ The LM Head layer. This method must be overwritten by all the models that have a lm head.
+
+ Return:
+ `keras.layers.Layer`: The LM head layer if the model has one, None if not.
+ """
+ return None
+
+ def resize_token_embeddings(self, new_num_tokens: int | None = None) -> keras.layers.Embedding | tf.Variable:
+ """
+ Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`.
+
+ Takes care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
+
+ Arguments:
+ new_num_tokens (`int`, *optional*):
+ The number of new tokens in the embedding matrix. Increasing the size will add newly initialized
+ vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just
+ returns a pointer to the input tokens without doing anything.
+
+ Return:
+ `tf.Variable` or `keras.layers.Embedding`: Pointer to the input tokens of the model.
+ """
+ # TODO (joao): flagged for replacement (by `_v2_resized_token_embeddings`) due to embeddings refactor
+
+ # Run the new code path if the model has a keras embeddings layer
+ if isinstance(self.get_input_embeddings(), keras.layers.Embedding):
+ return self._v2_resized_token_embeddings(new_num_tokens)
+
+ if new_num_tokens is None or new_num_tokens == self.config.vocab_size:
+ return self._get_word_embedding_weight(self.get_input_embeddings())
+
+ model_embeds = self._resize_token_embeddings(new_num_tokens)
+
+ # Update base model and current model config
+ self.config.vocab_size = new_num_tokens
+
+ return model_embeds
+
+ def _v2_resized_token_embeddings(self, new_num_tokens: int | None = None) -> keras.layers.Embedding:
+ """
+ Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`.
+
+ Arguments:
+ new_num_tokens (`int`, *optional*):
+ The number of new tokens in the embedding matrix. Increasing the size will add newly initialized
+ vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just
+ returns a pointer to the input tokens without doing anything.
+
+ Return:
+ `keras.layers.Embedding`: Pointer to the input tokens of the model.
+ """
+ if new_num_tokens is None or new_num_tokens == self.config.vocab_size:
+ return self.get_input_embeddings()
+
+ model_embeds = self._v2_resize_token_embeddings(new_num_tokens)
+
+ # Update base model and current model config
+ self.config.vocab_size = new_num_tokens
+
+ return model_embeds
+
+ def _get_word_embedding_weight(model, embedding_layer):
+ # TODO (joao): flagged for detection due to embeddings refactor
+
+ # If the variable holds the weights themselves, return them
+ if isinstance(embedding_layer, tf.Tensor):
+ return embedding_layer
+ # Otherwise, try to get them from the layer's attributes
+
+ embeds = getattr(embedding_layer, "weight", None)
+ if embeds is not None:
+ return embeds
+
+ embeds = getattr(embedding_layer, "decoder", None)
+ if embeds is not None:
+ return embeds
+
+ # The reason why the attributes don't exist might be
+ # because the model is not built, so retry getting
+ # the argument after building the model
+ model.build_in_name_scope()
+
+ embeds = getattr(embedding_layer, "weight", None)
+ if embeds is not None:
+ return embeds
+
+ embeds = getattr(embedding_layer, "decoder", None)
+ if embeds is not None:
+ return embeds
+
+ return None
+
+ def _resize_token_embeddings(self, new_num_tokens):
+ # TODO (joao): flagged for replacement (by `_v2_resize_token_embeddings`) due to embeddings refactor
+ old_embeddings = self._get_word_embedding_weight(self.get_input_embeddings())
+ new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
+
+ # if word embeddings are not tied, make sure that lm head bias is resized as well
+ if self.get_bias() is not None:
+ old_lm_head_bias = self.get_bias()
+ new_lm_head_bias = self._get_resized_lm_head_bias(old_lm_head_bias, new_num_tokens)
+
+ self.set_bias(new_lm_head_bias)
+
+ # if word embeddings are not tied, make sure that lm head decoder is resized as well
+ if self.get_output_embeddings() is not None:
+ old_lm_head_decoder = self._get_word_embedding_weight(self.get_output_embeddings())
+ new_lm_head_decoder = self._get_resized_lm_head_decoder(old_lm_head_decoder, new_num_tokens)
+
+ self.set_output_embeddings(new_lm_head_decoder)
+
+ self.set_input_embeddings(new_embeddings)
+
+ return self.get_input_embeddings()
+
+ def _v2_resize_token_embeddings(self, new_num_tokens):
+ old_embeddings = self.get_input_embeddings()
+ new_embeddings = self._v2_get_resized_embeddings(old_embeddings, new_num_tokens)
+ self.set_input_embeddings(new_embeddings)
+
+ # If word embeddings are not tied, make sure that lm head bias is resized as well
+ if self.get_bias() is not None:
+ old_lm_head_bias = self.get_bias()
+ new_lm_head_bias = self._v2_get_resized_lm_head_bias(old_lm_head_bias, new_num_tokens)
+ self.set_bias(new_lm_head_bias)
+
+ # If word embeddings are not tied, make sure that lm head decoder is resized as well.
+ tied_weights = self.get_input_embeddings() == self.get_output_embeddings()
+ if self.get_output_embeddings() is not None and not tied_weights:
+ old_lm_head_decoder = self._get_word_embedding_weight(self.get_output_embeddings())
+ # TODO (joao): this one probably needs a v2 version with other models
+ new_lm_head_decoder = self._get_resized_lm_head_decoder(old_lm_head_decoder, new_num_tokens)
+ self.set_output_embeddings(new_lm_head_decoder)
+
+ return self.get_input_embeddings()
+
+ def _get_resized_lm_head_bias(self, old_lm_head_bias, new_num_tokens):
+ """
+ Build a resized bias from the old ones. Increasing the size will add newly initialized vectors at the end.
+ Reducing the size will remove vectors from the end
+
+ Args:
+ old_lm_head_bias (`tf.Variable`):
+ Old lm head bias to be resized.
+ new_num_tokens (`int`, *optional*):
+ New number of tokens in the linear matrix.
+
+ Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
+ vectors from the end. If not provided or `None`, just returns None
+
+ Return:
+ `tf.Variable`: Pointer to the resized bias.
+ """
+ # TODO (joao): flagged for replacement (by `_v2_get_resized_lm_head_bias`) due to embeddings refactor
+ new_lm_head_bias = {}
+
+ for attr, weight in old_lm_head_bias.items():
+ first_dim, old_num_tokens = (None, shape_list(weight)[0]) if tf.rank(weight) == 1 else shape_list(weight)
+ size_diff = new_num_tokens - old_num_tokens
+ final_shape = [new_num_tokens] if first_dim is None else [first_dim, new_num_tokens]
+
+ # initialize new bias
+ if tf.math.greater(size_diff, 0):
+ padding_shape = [[0, size_diff]] if first_dim is None else [[0, 0], [0, size_diff]]
+ current_bias = tf.pad(weight.value(), tf.convert_to_tensor(padding_shape), constant_values=-1)
+ num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
+ mask_shape = [num_tokens_to_copy] if first_dim is None else [1, num_tokens_to_copy]
+ bias_mask = tf.fill(tf.convert_to_tensor(mask_shape), True)
+ bias_mask = tf.pad(bias_mask, tf.convert_to_tensor(padding_shape), constant_values=False)
+ else:
+ slice_from = [0] if first_dim is None else [0, 0]
+ current_bias = tf.slice(
+ weight.value(), tf.convert_to_tensor(slice_from), tf.convert_to_tensor(final_shape)
+ )
+ bias_mask = tf.fill(tf.convert_to_tensor(final_shape), True)
+
+ new_bias = self.add_weight(
+ shape=final_shape,
+ initializer="zeros",
+ trainable=True,
+ name=weight.name.split(":")[0],
+ )
+ init_bias = tf.where(bias_mask, current_bias, new_bias.value())
+
+ new_bias.assign(init_bias)
+ new_lm_head_bias[attr] = new_bias
+
+ return new_lm_head_bias
+
+ def _v2_get_resized_lm_head_bias(
+ self, old_lm_head_bias: dict[str, tf.Variable], new_num_tokens: int
+ ) -> dict[str, tf.Tensor]:
+ """
+ Build a resized bias from the old ones. Increasing the size will add newly initialized vectors at the end.
+ Reducing the size will remove vectors from the end
+
+ Args:
+ old_lm_head_bias (`dict[str, tf.Variable]`):
+ Old lm head bias to be resized.
+ new_num_tokens (`int`):
+ New number of tokens in the linear matrix. Increasing the size will add newly initialized vectors at
+ the end. Reducing the size will remove vectors from the end.
+
+ Return:
+ `tf.Tensor`: Values for the resized bias.
+ """
+ new_lm_head_bias = {}
+
+ for attr, weight in old_lm_head_bias.items():
+ # Determine the size difference (depending on the shape)
+ first_dim, old_num_tokens = (None, shape_list(weight)[0]) if tf.rank(weight) == 1 else shape_list(weight)
+ size_diff = new_num_tokens - old_num_tokens
+
+ # Copy the old bias values to the new bias
+ if old_num_tokens > new_num_tokens:
+ new_bias = weight.value()[..., :new_num_tokens]
+ else:
+ padding_shape = [[0, size_diff]] if first_dim is None else [[0, 0], [0, size_diff]]
+ new_bias = tf.pad(weight.value(), tf.convert_to_tensor(padding_shape))
+
+ new_lm_head_bias[attr] = new_bias
+ return new_lm_head_bias
+
+ def _get_resized_lm_head_decoder(self, old_lm_head_decoder, new_num_tokens):
+ """
+ Build a resized decoder from the old ones. Increasing the size will add newly initialized vectors at the end.
+ Reducing the size will remove vectors from the end
+
+ Args:
+ old_lm_head_decoder (`tf.Variable`):
+ Old lm head decoder to be resized.
+ new_num_tokens (`int`, *optional*):
+ New number of tokens in the linear matrix.
+
+ Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
+ vectors from the end. If not provided or `None`, just returns None
+
+ Return:
+ `tf.Variable`: Pointer to the resized decoder or None if the output embeddings are different from the input
+ ones.
+ """
+ new_lm_head_decoder = old_lm_head_decoder
+ is_input_output_equals = tf.reduce_any(
+ self._get_word_embedding_weight(self.get_input_embeddings()) == old_lm_head_decoder
+ )
+
+ if old_lm_head_decoder is not None and not is_input_output_equals:
+ old_embedding_dim = shape_list(old_lm_head_decoder)[1]
+ decoder_mask, current_decoder = init_copy_embeddings(old_lm_head_decoder, new_num_tokens)
+ new_lm_head_decoder = self.add_weight(
+ shape=(new_num_tokens, old_embedding_dim),
+ initializer="zeros",
+ trainable=True,
+ name=old_lm_head_decoder.name.split(":")[0],
+ )
+ init_decoder = tf.where(decoder_mask, current_decoder, new_lm_head_decoder.value())
+
+ new_lm_head_decoder.assign(init_decoder)
+
+ return new_lm_head_decoder
+
+ def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None) -> tf.Variable:
+ """
+ Build a resized Embedding weights from a provided token Embedding weights. Increasing the size will add newly
+ initialized vectors at the end. Reducing the size will remove vectors from the end
+
+ Args:
+ old_embeddings (`tf.Variable`):
+ Old embeddings to be resized.
+ new_num_tokens (`int`, *optional*):
+ New number of tokens in the embedding matrix.
+
+ Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
+ vectors from the end. If not provided or `None`, just returns a pointer to the input tokens
+ `tf.Variable` module of the model without doing anything.
+
+ Return:
+ `tf.Variable`: Pointer to the resized Embedding Module or the old Embedding Module if `new_num_tokens` is
+ `None`
+ """
+ # TODO (joao): flagged for replacement (by `_v2_get_resized_embeddings`) due to embeddings refactor
+ old_embedding_dim = shape_list(old_embeddings)[1]
+ init_range = getattr(self.config, "initializer_range", 0.02)
+ embeddings_mask, current_embeddings = init_copy_embeddings(old_embeddings, new_num_tokens)
+ new_embeddings = self.add_weight(
+ name=old_embeddings.name.split(":")[0],
+ shape=[new_num_tokens, old_embedding_dim],
+ initializer=get_initializer(init_range),
+ dtype=tf.float32,
+ )
+ init_embeddings = tf.where(embeddings_mask, current_embeddings, new_embeddings.value())
+
+ new_embeddings.assign(init_embeddings)
+
+ return new_embeddings
+
+ def _v2_get_resized_embeddings(
+ self, old_embeddings: keras.layers.Embedding, new_num_tokens: int
+ ) -> keras.layers.Embedding:
+ """
+ Build a resized Embedding layer from a provided Embedding layer. Increasing the size will add newly initialized
+ vectors at the end. Reducing the size will remove vectors from the end.
+
+ Args:
+ old_embeddings (`keras.layers.Embedding`):
+ Old embeddings to be resized.
+ new_num_tokens (`int`, *optional*):
+ New number of tokens in the embedding matrix.
+
+ Return:
+ `keras.layers.Embedding`: Resized Embedding layer.
+ """
+
+ # Get the initialization range for the embeddings
+ init_range = 0.02 # default value
+ potential_initialization_variable_names = [
+ "initializer_range", # most common
+ "initializer_factor", # e.g. T5
+ "init_std", # e.g BART
+ ]
+ for var_name in potential_initialization_variable_names:
+ if hasattr(self.config, var_name):
+ init_range = getattr(self.config, var_name)
+
+ # Get a new (initialized) embeddings layer
+ new_embeddings = keras.layers.Embedding(
+ input_dim=new_num_tokens,
+ output_dim=old_embeddings.output_dim,
+ embeddings_initializer=keras.initializers.TruncatedNormal(stddev=init_range),
+ name=old_embeddings.embeddings.name[:-13], # exact same scoped name except "/embeddings:0"
+ )
+ new_embeddings(tf.constant([[0]]))
+
+ # Copy the old embeddings to the new embeddings
+ if old_embeddings.input_dim >= new_num_tokens:
+ init_embeddings = old_embeddings.embeddings[:new_num_tokens]
+ else:
+ init_embeddings = tf.concat(
+ [old_embeddings.embeddings, new_embeddings.embeddings[old_embeddings.input_dim :]], axis=0
+ )
+ new_embeddings.embeddings.assign(init_embeddings)
+ return new_embeddings
+
+ def prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the base model.
+
+ Arguments:
+ heads_to_prune (`dict[int, list[int]]`):
+ Dictionary with keys being selected layer indices (`int`) and associated values being the list of heads
+ to prune in said layer (list of `int`). For instance {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on
+ layer 1 and heads 2 and 3 on layer 2.
+ """
+ raise NotImplementedError
+
+ def save_pretrained(
+ self,
+ save_directory,
+ saved_model=False,
+ version=1,
+ push_to_hub=False,
+ signatures=None,
+ max_shard_size: int | str = "5GB",
+ create_pr: bool = False,
+ safe_serialization: bool = False,
+ token: str | bool | None = None,
+ **kwargs,
+ ):
+ """
+ Save a model and its configuration file to a directory, so that it can be re-loaded using the
+ [`~TFPreTrainedModel.from_pretrained`] class method.
+
+ Arguments:
+ save_directory (`str`):
+ Directory to which to save. Will be created if it doesn't exist.
+ saved_model (`bool`, *optional*, defaults to `False`):
+ If the model has to be saved in saved model format as well or not.
+ version (`int`, *optional*, defaults to 1):
+ The version of the saved model. A saved model needs to be versioned in order to be properly loaded by
+ TensorFlow Serving as detailed in the official documentation
+ https://www.tensorflow.org/tfx/serving/serving_basic
+ push_to_hub (`bool`, *optional*, defaults to `False`):
+ Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
+ repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
+ namespace).
+ signatures (`dict` or `tf.function`, *optional*):
+ Model's signature used for serving. This will be passed to the `signatures` argument of model.save().
+ max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
+ The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
+ lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`).
+
+
+
+ If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard
+ which will be bigger than `max_shard_size`.
+
+
+
+ create_pr (`bool`, *optional*, defaults to `False`):
+ Whether or not to create a PR with the uploaded files or directly commit.
+ safe_serialization (`bool`, *optional*, defaults to `False`):
+ Whether to save the model using `safetensors` or the traditional TensorFlow way (that uses `h5`).
+ token (`str` or `bool`, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
+ the token generated when running `hf auth login` (stored in `~/.huggingface`).
+ kwargs (`dict[str, Any]`, *optional*):
+ Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
+ """
+ use_auth_token = kwargs.pop("use_auth_token", None)
+
+ if use_auth_token is not None:
+ warnings.warn(
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+ FutureWarning,
+ )
+ if token is not None:
+ raise ValueError(
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+ )
+ token = use_auth_token
+
+ if token is not None:
+ kwargs["token"] = token
+
+ if os.path.isfile(save_directory):
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
+ return
+
+ os.makedirs(save_directory, exist_ok=True)
+
+ if push_to_hub:
+ commit_message = kwargs.pop("commit_message", None)
+ repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
+ repo_id = self._create_repo(repo_id, **kwargs)
+ files_timestamps = self._get_files_timestamps(save_directory)
+
+ if saved_model:
+ # If `torch_dtype` is in the config with a torch dtype class as the value, we need to change it to string.
+ # (Although TF doesn't care about this attribute, we can't just remove it or set it to `None`.)
+ if getattr(self.config, "torch_dtype", None) is not None and not isinstance(self.config.torch_dtype, str):
+ self.config.torch_dtype = str(self.config.torch_dtype).split(".")[1]
+ if signatures is None:
+ serving_default = self.serving.get_concrete_function(self.input_signature)
+ if any(spec.dtype == tf.int32 for spec in self.input_signature.values()):
+ int64_spec = {
+ key: tf.TensorSpec(
+ shape=spec.shape, dtype=tf.int64 if spec.dtype == tf.int32 else spec.dtype, name=spec.name
+ )
+ for key, spec in self.input_signature.items()
+ }
+ int64_serving = self.serving.get_concrete_function(int64_spec)
+ signatures = {"serving_default": serving_default, "int64_serving": int64_serving}
+ else:
+ signatures = serving_default
+ saved_model_dir = os.path.join(save_directory, "saved_model", str(version))
+ self.save(saved_model_dir, include_optimizer=False, signatures=signatures)
+ logger.info(f"Saved model created in {saved_model_dir}")
+
+ # Save configuration file
+ self.config.architectures = [self.__class__.__name__[2:]]
+
+ # If we have a custom model, we copy the file defining it in the folder and set the attributes so it can be
+ # loaded from the Hub.
+ if self._auto_class is not None:
+ custom_object_save(self, save_directory, config=self.config)
+
+ self.config.save_pretrained(save_directory)
+ if self.can_generate():
+ self.generation_config.save_pretrained(save_directory)
+
+ # If we save using the predefined names, we can load using `from_pretrained`
+ weights_name = SAFE_WEIGHTS_NAME if safe_serialization else TF2_WEIGHTS_NAME
+ output_model_file = os.path.join(save_directory, weights_name)
+
+ shards, index = tf_shard_checkpoint(self.weights, max_shard_size, weights_name=weights_name)
+
+ # Clean the folder from a previous save
+ for filename in os.listdir(save_directory):
+ full_filename = os.path.join(save_directory, filename)
+ # If we have a shard file that is not going to be replaced, we delete it, but only from the main process
+ # in distributed settings to avoid race conditions.
+ weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
+ if filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) and filename not in shards:
+ os.remove(full_filename)
+
+ if index is None:
+ if safe_serialization:
+ state_dict = {strip_model_name_and_prefix(w.name): w.value() for w in self.weights}
+ safe_save_file(state_dict, output_model_file, metadata={"format": "tf"})
+ else:
+ self.save_weights(output_model_file)
+ logger.info(f"Model weights saved in {output_model_file}")
+ else:
+ save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else TF2_WEIGHTS_INDEX_NAME
+ save_index_file = os.path.join(save_directory, save_index_file)
+ # Save the index as well
+ with open(save_index_file, "w", encoding="utf-8") as index_file:
+ content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+ index_file.write(content)
+ logger.info(
+ f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
+ f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the "
+ f"index located at {save_index_file}."
+ )
+ for shard_file, shard in shards.items():
+ if safe_serialization:
+ shard_state_dict = {strip_model_name_and_prefix(w.name): w.value() for w in shard}
+ safe_save_file(
+ shard_state_dict, os.path.join(save_directory, shard_file), metadata={"format": "tf"}
+ )
+ else:
+ with h5py.File(os.path.join(save_directory, shard_file), mode="w") as shard_file:
+ layers = []
+ for layer in sorted(shard, key=lambda x: x.name):
+ if "model." in layer.name or len(layer.name.split("/")) == 1:
+ layer_name = layer.name
+ else:
+ layer_name = "/".join(layer.name.split("/")[1:])
+ param_dset = shard_file.create_dataset(
+ layer_name, layer.numpy().shape, dtype=layer.numpy().dtype
+ )
+ param_dset[:] = layer.numpy()
+ layers.append(layer_name.encode("utf8"))
+ save_attributes_to_hdf5_group(shard_file, "layer_names", layers)
+
+ if push_to_hub:
+ self._upload_modified_files(
+ save_directory,
+ repo_id,
+ files_timestamps,
+ commit_message=commit_message,
+ token=token,
+ )
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ pretrained_model_name_or_path: str | os.PathLike | None,
+ *model_args,
+ config: PretrainedConfig | str | os.PathLike | None = None,
+ cache_dir: str | os.PathLike | None = None,
+ ignore_mismatched_sizes: bool = False,
+ force_download: bool = False,
+ local_files_only: bool = False,
+ token: str | bool | None = None,
+ revision: str = "main",
+ use_safetensors: bool | None = None,
+ **kwargs,
+ ):
+ r"""
+ Instantiate a pretrained TF 2.0 model from a pre-trained model configuration.
+
+ The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
+ pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
+ task.
+
+ The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
+ weights are discarded.
+
+ Parameters:
+ pretrained_model_name_or_path (`str`, *optional*):
+ Can be either:
+
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+ - A path to a *directory* containing model weights saved using
+ [`~TFPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+ - A path or url to a *PyTorch state_dict save file* (e.g, `./pt_model/pytorch_model.bin`). In this
+ case, `from_pt` should be set to `True` and a configuration object should be provided as `config`
+ argument. This loading path is slower than converting the PyTorch model in a TensorFlow model
+ using the provided conversion scripts and loading the TensorFlow model afterwards.
+ - `None` if you are both providing the configuration and state dictionary (resp. with keyword
+ arguments `config` and `state_dict`).
+ model_args (sequence of positional arguments, *optional*):
+ All remaining positional arguments will be passed to the underlying model's `__init__` method.
+ config (`Union[PretrainedConfig, str]`, *optional*):
+ Can be either:
+
+ - an instance of a class derived from [`PretrainedConfig`],
+ - a string valid as input to [`~PretrainedConfig.from_pretrained`].
+
+ Configuration for the model to use instead of an automatically loaded configuration. Configuration can
+ be automatically loaded when:
+
+ - The model is a model provided by the library (loaded with the *model id* string of a pretrained
+ model).
+ - The model was saved using [`~TFPreTrainedModel.save_pretrained`] and is reloaded by supplying the
+ save directory.
+ - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
+ configuration JSON file named *config.json* is found in the directory.
+ from_pt (`bool`, *optional*, defaults to `False`):
+ Load the model weights from a PyTorch state_dict save file (see docstring of
+ `pretrained_model_name_or_path` argument).
+ ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
+ Whether or not to raise an error if some of the weights from the checkpoint do not have the same size
+ as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
+ checkpoint with 3 labels).
+ cache_dir (`str`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
+ standard cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download:
+ Deprecated and ignored. All downloads are now resumed by default when possible.
+ Will be removed in v5 of Transformers.
+ proxies:
+ (`dict[str, str], `optional`): A dictionary of proxy servers to use by protocol or endpoint, e.g.,
+ `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info(`bool`, *optional*, defaults to `False`): Whether ot not to also return a
+ dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only(`bool`, *optional*, defaults to `False`):
+ Whether or not to only look at local files (e.g., not try downloading the model).
+ token (`str` or `bool`, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
+ the token generated when running `hf auth login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+
+
+
+
+ To test a pull request you made on the Hub, you can pass `revision="refs/pr/"`.
+
+
+
+ mirror (`str`, *optional*):
+ Mirror source to accelerate downloads in China. If you are from China and have an accessibility
+ problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
+ Please refer to the mirror site for more information.
+ subfolder (`str`, *optional*, defaults to `""`):
+ In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
+ specify the folder name here.
+ tf_to_pt_weight_rename (`Callable`, *optional*):
+ A function that is called to transform the names of weights during the PyTorch to TensorFlow
+ crossloading process. This is not necessary for most models, but is useful to allow composite models to
+ be crossloaded correctly.
+ use_safetensors (`bool`, *optional*, defaults to `None`):
+ Whether or not to use `safetensors` checkpoints. Defaults to `None`. If not specified and `safetensors`
+ is not installed, it will be set to `False`.
+ kwargs (remaining dictionary of keyword arguments, *optional*):
+ Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
+ `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
+ automatically loaded:
+
+ - If a configuration is provided with `config`, `**kwargs` will be directly passed to the
+ underlying model's `__init__` method (we assume all relevant updates to the configuration have
+ already been done)
+ - If a configuration is not provided, `kwargs` will be first passed to the configuration class
+ initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that
+ corresponds to a configuration attribute will be used to override said attribute with the
+ supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
+ will be passed to the underlying model's `__init__` function.
+
+ Examples:
+
+ ```python
+ >>> from transformers import BertConfig, TFBertModel
+
+ >>> # Download model and configuration from huggingface.co and cache.
+ >>> model = TFBertModel.from_pretrained("google-bert/bert-base-uncased")
+ >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
+ >>> model = TFBertModel.from_pretrained("./test/saved_model/")
+ >>> # Update configuration during loading.
+ >>> model = TFBertModel.from_pretrained("google-bert/bert-base-uncased", output_attentions=True)
+ >>> assert model.config.output_attentions == True
+ >>> # Loading from a Pytorch model file instead of a TensorFlow checkpoint (slower, for example purposes, not runnable).
+ >>> config = BertConfig.from_json_file("./pt_model/my_pt_model_config.json")
+ >>> model = TFBertModel.from_pretrained("./pt_model/my_pytorch_model.bin", from_pt=True, config=config)
+ ```"""
+ from_pt = kwargs.pop("from_pt", False)
+ resume_download = kwargs.pop("resume_download", None)
+ proxies = kwargs.pop("proxies", None)
+ output_loading_info = kwargs.pop("output_loading_info", False)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ trust_remote_code = kwargs.pop("trust_remote_code", None)
+ _ = kwargs.pop("mirror", None)
+ load_weight_prefix = kwargs.pop("load_weight_prefix", None)
+ from_pipeline = kwargs.pop("_from_pipeline", None)
+ from_auto_class = kwargs.pop("_from_auto", False)
+ subfolder = kwargs.pop("subfolder", "")
+ commit_hash = kwargs.pop("_commit_hash", None)
+ tf_to_pt_weight_rename = kwargs.pop("tf_to_pt_weight_rename", None)
+
+ # Not relevant for TF models
+ _ = kwargs.pop("adapter_kwargs", None)
+
+ if use_auth_token is not None:
+ warnings.warn(
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+ FutureWarning,
+ )
+ if token is not None:
+ raise ValueError(
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+ )
+ token = use_auth_token
+
+ if trust_remote_code is True:
+ logger.warning(
+ "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is"
+ " ignored."
+ )
+
+ user_agent = {"file_type": "model", "framework": "tensorflow", "from_auto_class": from_auto_class}
+ if from_pipeline is not None:
+ user_agent["using_pipeline"] = from_pipeline
+
+ if is_offline_mode() and not local_files_only:
+ logger.info("Offline mode: forcing local_files_only=True")
+ local_files_only = True
+
+ if use_safetensors is None and not is_safetensors_available():
+ use_safetensors = False
+
+ # Load config if we don't provide a configuration
+ if not isinstance(config, PretrainedConfig):
+ config_path = config if config is not None else pretrained_model_name_or_path
+ config, model_kwargs = cls.config_class.from_pretrained(
+ config_path,
+ cache_dir=cache_dir,
+ return_unused_kwargs=True,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ _from_auto=from_auto_class,
+ _from_pipeline=from_pipeline,
+ _commit_hash=commit_hash,
+ **kwargs,
+ )
+ else:
+ model_kwargs = kwargs
+
+ if commit_hash is None:
+ commit_hash = getattr(config, "_commit_hash", None)
+
+ # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
+ # index of the files.
+ is_sharded = False
+ # Load model
+ if pretrained_model_name_or_path is not None:
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+ is_local = os.path.isdir(pretrained_model_name_or_path)
+ if is_local:
+ if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
+ # Load from a PyTorch checkpoint in priority if from_pt
+ archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
+ elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)):
+ # Load from a sharded PyTorch checkpoint
+ archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
+ is_sharded = True
+ elif use_safetensors is not False and os.path.isfile(
+ os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME)
+ ):
+ # Load from a safetensors checkpoint
+ archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME)
+ elif use_safetensors is not False and os.path.isfile(
+ os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
+ ):
+ # Load from a sharded safetensors checkpoint
+ archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
+ is_sharded = True
+ elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
+ # Load from a TF 2.0 checkpoint
+ archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
+ elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME)):
+ # Load from a sharded TF 2.0 checkpoint
+ archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME)
+ is_sharded = True
+
+ # At this stage we don't have a weight file so we will raise an error.
+ elif use_safetensors:
+ raise OSError(
+ f"Error no file named {SAFE_WEIGHTS_NAME} or {SAFE_WEIGHTS_INDEX_NAME} found in directory {pretrained_model_name_or_path}. "
+ f"Please make sure that the model has been saved with `safe_serialization=True` or do not "
+ f"set `use_safetensors=True`."
+ )
+ elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)) or os.path.isfile(
+ os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
+ ):
+ raise OSError(
+ f"Error no file named {TF2_WEIGHTS_NAME} or {SAFE_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
+ "but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those "
+ "weights."
+ )
+ else:
+ raise OSError(
+ f"Error no file named {TF2_WEIGHTS_NAME}, {SAFE_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
+ f"{pretrained_model_name_or_path}."
+ )
+ elif os.path.isfile(pretrained_model_name_or_path):
+ archive_file = pretrained_model_name_or_path
+ is_local = True
+ elif os.path.isfile(pretrained_model_name_or_path + ".index"):
+ archive_file = pretrained_model_name_or_path + ".index"
+ is_local = True
+ elif is_remote_url(pretrained_model_name_or_path):
+ filename = pretrained_model_name_or_path
+ resolved_archive_file = download_url(pretrained_model_name_or_path)
+ else:
+ # set correct filename
+ if from_pt:
+ filename = WEIGHTS_NAME
+ elif use_safetensors is not False:
+ filename = SAFE_WEIGHTS_NAME
+ else:
+ filename = TF2_WEIGHTS_NAME
+
+ try:
+ # Load from URL or cache if already cached
+ cached_file_kwargs = {
+ "cache_dir": cache_dir,
+ "force_download": force_download,
+ "proxies": proxies,
+ "resume_download": resume_download,
+ "local_files_only": local_files_only,
+ "token": token,
+ "user_agent": user_agent,
+ "revision": revision,
+ "subfolder": subfolder,
+ "_raise_exceptions_for_gated_repo": False,
+ "_raise_exceptions_for_missing_entries": False,
+ "_commit_hash": commit_hash,
+ }
+ resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
+
+ # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
+ # result when internet is up, the repo and revision exist, but the file does not.
+ if resolved_archive_file is None and filename == SAFE_WEIGHTS_NAME:
+ # Did not find the safetensors file, let's fallback to TF.
+ # No support for sharded safetensors yet, so we'll raise an error if that's all we find.
+ filename = TF2_WEIGHTS_NAME
+ resolved_archive_file = cached_file(
+ pretrained_model_name_or_path, TF2_WEIGHTS_NAME, **cached_file_kwargs
+ )
+ if resolved_archive_file is None and filename == TF2_WEIGHTS_NAME:
+ # Maybe the checkpoint is sharded, we try to grab the index name in this case.
+ resolved_archive_file = cached_file(
+ pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME, **cached_file_kwargs
+ )
+ if resolved_archive_file is not None:
+ is_sharded = True
+ if resolved_archive_file is None and filename == WEIGHTS_NAME:
+ # Maybe the checkpoint is sharded, we try to grab the index name in this case.
+ resolved_archive_file = cached_file(
+ pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **cached_file_kwargs
+ )
+ if resolved_archive_file is not None:
+ is_sharded = True
+ if resolved_archive_file is None:
+ # Otherwise, maybe there is a PyTorch or Flax model file. We try those to give a helpful error
+ # message.
+ has_file_kwargs = {
+ "revision": revision,
+ "proxies": proxies,
+ "token": token,
+ "cache_dir": cache_dir,
+ "local_files_only": local_files_only,
+ }
+ if has_file(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME, **has_file_kwargs):
+ is_sharded = True
+ elif has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
+ raise OSError(
+ f"{pretrained_model_name_or_path} does not appear to have a file named"
+ f" {TF2_WEIGHTS_NAME} but there is a file for PyTorch weights. Use `from_pt=True` to"
+ " load this model from those weights."
+ )
+ else:
+ raise OSError(
+ f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME},"
+ f" {TF2_WEIGHTS_NAME} or {TF_WEIGHTS_NAME}"
+ )
+
+ except OSError:
+ # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted
+ # to the original exception.
+ raise
+ except Exception:
+ # For any other exception, we throw a generic error.
+
+ raise OSError(
+ f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
+ " from 'https://huggingface.co/models', make sure you don't have a local directory with the"
+ f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
+ f" directory containing a file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME} or {TF_WEIGHTS_NAME}"
+ )
+ if is_local:
+ logger.info(f"loading weights file {archive_file}")
+ resolved_archive_file = archive_file
+ filename = resolved_archive_file.split(os.path.sep)[-1]
+ else:
+ logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}")
+ else:
+ resolved_archive_file = None
+
+ # We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
+ if is_sharded:
+ # resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
+ resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
+ pretrained_model_name_or_path,
+ resolved_archive_file,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ token=token,
+ user_agent=user_agent,
+ revision=revision,
+ _commit_hash=commit_hash,
+ )
+
+ safetensors_from_pt = False
+ if filename == SAFE_WEIGHTS_NAME:
+ with safe_open(resolved_archive_file, framework="tf") as f:
+ safetensors_metadata = f.metadata()
+ if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
+ raise OSError(
+ f"The safetensors archive passed at {resolved_archive_file} does not contain the valid metadata."
+ " Make sure you save your model with the `save_pretrained` method."
+ )
+ safetensors_from_pt = safetensors_metadata.get("format") == "pt"
+ elif filename == SAFE_WEIGHTS_INDEX_NAME:
+ with safe_open(resolved_archive_file[0], framework="tf") as f:
+ safetensors_metadata = f.metadata()
+ if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
+ raise OSError(
+ f"The safetensors archive passed at {resolved_archive_file} does not contain the valid metadata."
+ " Make sure you save your model with the `save_pretrained` method."
+ )
+ safetensors_from_pt = safetensors_metadata.get("format") == "pt"
+
+ config.name_or_path = pretrained_model_name_or_path
+
+ # composed models, *e.g.* TFRag, require special treatment when it comes to loading
+ # pre-trained weights.
+ if cls._requires_load_weight_prefix and model_kwargs.get("name") is not None:
+ model_kwargs["load_weight_prefix"] = load_weight_prefix + "/" + model_kwargs.get("name")
+
+ # Instantiate model.
+ model = cls(config, *model_args, **model_kwargs)
+
+ if tf_to_pt_weight_rename is None and hasattr(model, "tf_to_pt_weight_rename"):
+ # TODO Matt: This is a temporary workaround to allow weight renaming, but requires a method
+ # to be defined for each class that requires a rename. We can probably just have a class-level
+ # dict and a single top-level method or something and cut down a lot of boilerplate code
+ tf_to_pt_weight_rename = model.tf_to_pt_weight_rename
+
+ if from_pt:
+ from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
+
+ # Load from a PyTorch checkpoint
+ return load_pytorch_checkpoint_in_tf2_model(
+ model,
+ resolved_archive_file,
+ allow_missing_keys=True,
+ output_loading_info=output_loading_info,
+ _prefix=load_weight_prefix,
+ tf_to_pt_weight_rename=tf_to_pt_weight_rename,
+ )
+
+ # we might need to extend the variable scope for composite models
+ if load_weight_prefix is not None:
+ with tf.compat.v1.variable_scope(load_weight_prefix):
+ model.build_in_name_scope() # build the network with dummy inputs
+ else:
+ model.build_in_name_scope() # build the network with dummy inputs
+
+ if safetensors_from_pt and not is_sharded:
+ from .modeling_tf_pytorch_utils import load_pytorch_state_dict_in_tf2_model
+
+ with safe_open(resolved_archive_file, framework="tf") as safetensors_archive:
+ # Load from a PyTorch safetensors checkpoint
+ # We load in TF format here because PT weights often need to be transposed, and this is much
+ # faster on GPU. Loading as numpy and transposing on CPU adds several seconds to load times.
+ return load_pytorch_state_dict_in_tf2_model(
+ model,
+ safetensors_archive,
+ tf_inputs=False, # No need to build the model again
+ allow_missing_keys=True,
+ output_loading_info=output_loading_info,
+ _prefix=load_weight_prefix,
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
+ tf_to_pt_weight_rename=tf_to_pt_weight_rename,
+ )
+ elif safetensors_from_pt:
+ from .modeling_tf_pytorch_utils import load_sharded_pytorch_safetensors_in_tf2_model
+
+ return load_sharded_pytorch_safetensors_in_tf2_model(
+ model,
+ resolved_archive_file,
+ tf_inputs=False,
+ allow_missing_keys=True,
+ output_loading_info=output_loading_info,
+ _prefix=load_weight_prefix,
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
+ tf_to_pt_weight_rename=tf_to_pt_weight_rename,
+ )
+
+ # 'by_name' allow us to do transfer learning by skipping/adding layers
+ # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
+ try:
+ if is_sharded:
+ for file in resolved_archive_file:
+ os.path.isfile(file), f"Error retrieving files {file}"
+ if filename == SAFE_WEIGHTS_INDEX_NAME:
+ missing_keys, unexpected_keys, mismatched_keys = load_tf_sharded_weights_from_safetensors(
+ model,
+ resolved_archive_file,
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
+ _prefix=load_weight_prefix,
+ )
+ else:
+ missing_keys, unexpected_keys, mismatched_keys = load_tf_sharded_weights(
+ model,
+ resolved_archive_file,
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
+ _prefix=load_weight_prefix,
+ )
+ else:
+ # Handles both H5 and safetensors
+ missing_keys, unexpected_keys, mismatched_keys = load_tf_weights(
+ model,
+ resolved_archive_file,
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
+ _prefix=load_weight_prefix,
+ )
+ except OSError as e:
+ try:
+ with open(resolved_archive_file) as f:
+ if f.read().startswith("version"):
+ raise OSError(
+ "You seem to have cloned a repository without having git-lfs installed. Please install "
+ "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
+ "you cloned."
+ )
+ else:
+ raise ValueError from e
+ except (UnicodeDecodeError, ValueError):
+ raise OSError(
+ "Unable to load weights from h5 file. "
+ "If you tried to load a TF 2.0 model from a PyTorch checkpoint, please set from_pt=True. "
+ )
+
+ if cls._keys_to_ignore_on_load_missing is not None:
+ for pat in cls._keys_to_ignore_on_load_missing:
+ missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
+
+ if cls._keys_to_ignore_on_load_unexpected is not None:
+ for pat in cls._keys_to_ignore_on_load_unexpected:
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
+
+ if len(unexpected_keys) > 0:
+ logger.warning(
+ f"Some layers from the model checkpoint at {pretrained_model_name_or_path} were not used when"
+ f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
+ f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
+ " with another architecture (e.g. initializing a BertForSequenceClassification model from a"
+ " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
+ f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical"
+ " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
+ )
+ else:
+ logger.warning(f"All model checkpoint layers were used when initializing {model.__class__.__name__}.\n")
+
+ if len(missing_keys) > 0:
+ logger.warning(
+ f"Some layers of {model.__class__.__name__} were not initialized from the model checkpoint at"
+ f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
+ " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
+ )
+ elif len(mismatched_keys) == 0:
+ logger.warning(
+ f"All the layers of {model.__class__.__name__} were initialized from the model checkpoint at"
+ f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
+ f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
+ " training."
+ )
+ if len(mismatched_keys) > 0:
+ mismatched_warning = "\n".join(
+ [
+ f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
+ for key, shape1, shape2 in mismatched_keys
+ ]
+ )
+ logger.warning(
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+ f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
+ f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
+ " to use it for predictions and inference."
+ )
+
+ # If it is a model with generation capabilities, attempt to load the generation config
+ if model.can_generate():
+ try:
+ model.generation_config = GenerationConfig.from_pretrained(
+ pretrained_model_name_or_path,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ _from_auto=from_auto_class,
+ _from_pipeline=from_pipeline,
+ **kwargs,
+ )
+ except OSError:
+ logger.info(
+ "Generation config file not found, using a generation config created from the model config."
+ )
+ pass
+
+ if output_loading_info:
+ loading_info = {
+ "missing_keys": missing_keys,
+ "unexpected_keys": unexpected_keys,
+ "mismatched_keys": mismatched_keys,
+ }
+
+ return model, loading_info
+
+ return model
+
+ def push_to_hub(
+ self,
+ repo_id: str,
+ use_temp_dir: bool | None = None,
+ commit_message: str | None = None,
+ private: bool | None = None,
+ max_shard_size: int | str | None = "10GB",
+ token: bool | str | None = None,
+ # (`use_auth_token` is deprecated: we have to keep it here as we don't have **kwargs)
+ use_auth_token: bool | str | None = None,
+ create_pr: bool = False,
+ **base_model_card_args,
+ ) -> str:
+ """
+ Upload the model files to the 🤗 Model Hub while synchronizing a local clone of the repo in `repo_path_or_name`.
+
+ Parameters:
+ repo_id (`str`):
+ The name of the repository you want to push your model to. It should contain your organization name
+ when pushing to a given organization.
+ use_temp_dir (`bool`, *optional*):
+ Whether or not to use a temporary directory to store the files saved before they are pushed to the Hub.
+ Will default to `True` if there is no directory named like `repo_id`, `False` otherwise.
+ commit_message (`str`, *optional*):
+ Message to commit while pushing. Will default to `"Upload model"`.
+ private (`bool`, *optional*):
+ Whether to make the repo private. If `None` (default), the repo will be public unless the organization's default is private. This value is ignored if the repo already exists.
+ token (`bool` or `str`, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `hf auth login` (stored in `~/.huggingface`). Will default to `True` if `repo_url`
+ is not specified.
+ max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
+ Only applicable for models. The maximum size for a checkpoint before being sharded. Checkpoints shard
+ will then be each of size lower than this size. If expressed as a string, needs to be digits followed
+ by a unit (like `"5MB"`).
+ create_pr (`bool`, *optional*, defaults to `False`):
+ Whether or not to create a PR with the uploaded files or directly commit.
+
+ Examples:
+
+ ```python
+ from transformers import TFAutoModel
+
+ model = TFAutoModel.from_pretrained("google-bert/bert-base-cased")
+
+ # Push the model to your namespace with the name "my-finetuned-bert".
+ model.push_to_hub("my-finetuned-bert")
+
+ # Push the model to an organization with the name "my-finetuned-bert".
+ model.push_to_hub("huggingface/my-finetuned-bert")
+ ```
+ """
+ if use_auth_token is not None:
+ warnings.warn(
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+ FutureWarning,
+ )
+ if token is not None:
+ raise ValueError(
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+ )
+ token = use_auth_token
+
+ if "repo_path_or_name" in base_model_card_args:
+ warnings.warn(
+ "The `repo_path_or_name` argument is deprecated and will be removed in v5 of Transformers. Use "
+ "`repo_id` instead."
+ )
+ repo_id = base_model_card_args.pop("repo_path_or_name")
+ # Deprecation warning will be sent after for repo_url and organization
+ repo_url = base_model_card_args.pop("repo_url", None)
+ organization = base_model_card_args.pop("organization", None)
+
+ if os.path.isdir(repo_id):
+ working_dir = repo_id
+ repo_id = repo_id.split(os.path.sep)[-1]
+ else:
+ working_dir = repo_id.split("/")[-1]
+
+ repo_id = self._create_repo(
+ repo_id, private=private, token=token, repo_url=repo_url, organization=organization
+ )
+
+ if use_temp_dir is None:
+ use_temp_dir = not os.path.isdir(working_dir)
+
+ with working_or_temp_dir(working_dir=working_dir, use_temp_dir=use_temp_dir) as work_dir:
+ files_timestamps = self._get_files_timestamps(work_dir)
+
+ # Save all files.
+ self.save_pretrained(work_dir, max_shard_size=max_shard_size)
+ if hasattr(self, "history") and hasattr(self, "create_model_card"):
+ # This is a Keras model and we might be able to fish out its History and make a model card out of it
+ base_model_card_args = {
+ "output_dir": work_dir,
+ "model_name": Path(repo_id).name,
+ }
+ base_model_card_args.update(base_model_card_args)
+ self.create_model_card(**base_model_card_args)
+
+ self._upload_modified_files(
+ work_dir,
+ repo_id,
+ files_timestamps,
+ commit_message=commit_message,
+ token=token,
+ create_pr=create_pr,
+ )
+
+ @classmethod
+ def register_for_auto_class(cls, auto_class="TFAutoModel"):
+ """
+ Register this class with a given auto class. This should only be used for custom models as the ones in the
+ library are already mapped with an auto class.
+
+
+
+ Args:
+ auto_class (`str` or `type`, *optional*, defaults to `"TFAutoModel"`):
+ The auto class to register this new model with.
+ """
+ if not isinstance(auto_class, str):
+ auto_class = auto_class.__name__
+
+ import transformers.models.auto as auto_module
+
+ if not hasattr(auto_module, auto_class):
+ raise ValueError(f"{auto_class} is not a valid auto class.")
+
+ cls._auto_class = auto_class
+
+
+class TFConv1D(keras.layers.Layer):
+ """
+ 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
+
+ Basically works like a linear layer but the weights are transposed.
+
+ Args:
+ nf (`int`):
+ The number of output features.
+ nx (`int`):
+ The number of input features.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation to use to initialize the weights.
+ kwargs (`dict[str, Any]`, *optional*):
+ Additional keyword arguments passed along to the `__init__` of `keras.layers.Layer`.
+ """
+
+ def __init__(self, nf, nx, initializer_range=0.02, **kwargs):
+ super().__init__(**kwargs)
+ self.nf = nf
+ self.nx = nx
+ self.initializer_range = initializer_range
+
+ def build(self, input_shape):
+ if self.built:
+ return
+ self.built = True
+ self.weight = self.add_weight(
+ "weight", shape=[self.nx, self.nf], initializer=get_initializer(self.initializer_range)
+ )
+ self.bias = self.add_weight("bias", shape=[1, self.nf], initializer=tf.zeros_initializer())
+
+ def call(self, x):
+ bz, sl = shape_list(x)[:2]
+
+ x = tf.reshape(x, [-1, self.nx])
+ x = tf.matmul(x, self.weight) + self.bias
+
+ x = tf.reshape(x, [bz, sl, self.nf])
+
+ return x
+
+
+class TFSharedEmbeddings(keras.layers.Layer):
+ r"""
+ Construct shared token embeddings.
+
+ The weights of the embedding layer is usually shared with the weights of the linear decoder when doing language
+ modeling.
+
+ Args:
+ vocab_size (`int`):
+ The size of the vocabulary, e.g., the number of unique tokens.
+ hidden_size (`int`):
+ The size of the embedding vectors.
+ initializer_range (`float`, *optional*):
+ The standard deviation to use when initializing the weights. If no value is provided, it will default to
+ \\(1/\sqrt{hidden\_size}\\).
+ kwargs (`dict[str, Any]`, *optional*):
+ Additional keyword arguments passed along to the `__init__` of `keras.layers.Layer`.
+ """
+
+ # TODO (joao): flagged for detection due to embeddings refactor
+
+ def __init__(self, vocab_size: int, hidden_size: int, initializer_range: float | None = None, **kwargs):
+ super().__init__(**kwargs)
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.initializer_range = hidden_size**-0.5 if initializer_range is None else initializer_range
+ warnings.warn(
+ "`TFSharedEmbeddings` is scheduled for deletion in v4.32, use `keras.layers.Embedding` instead.",
+ DeprecationWarning,
+ )
+
+ def build(self, input_shape):
+ """
+ Build shared token embedding layer Shared weights logic adapted from
+ https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
+ """
+ self.weight = self.add_weight(
+ "weight", shape=[self.vocab_size, self.hidden_size], initializer=get_initializer(self.initializer_range)
+ )
+ super().build(input_shape)
+
+ def get_config(self):
+ config = {
+ "vocab_size": self.vocab_size,
+ "hidden_size": self.hidden_size,
+ "initializer_range": self.initializer_range,
+ }
+ base_config = super().get_config()
+
+ return dict(list(base_config.items()) + list(config.items()))
+
+ def call(self, inputs: tf.Tensor, mode: str = "embedding") -> tf.Tensor:
+ """
+ Get token embeddings of inputs or decode final hidden state.
+
+ Args:
+ inputs (`tf.Tensor`):
+ In embedding mode, should be an int64 tensor with shape `[batch_size, length]`.
+
+ In linear mode, should be a float tensor with shape `[batch_size, length, hidden_size]`.
+ mode (`str`, defaults to `"embedding"`):
+ A valid value is either `"embedding"` or `"linear"`, the first one indicates that the layer should be
+ used as an embedding layer, the second one that the layer should be used as a linear decoder.
+
+ Returns:
+ `tf.Tensor`: In embedding mode, the output is a float32 embedding tensor, with shape `[batch_size, length,
+ embedding_size]`.
+
+ In linear mode, the output is a float32 with shape `[batch_size, length, vocab_size]`.
+
+ Raises:
+ ValueError: if `mode` is not valid.
+
+ Shared weights logic is adapted from
+ [here](https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24).
+ """
+ if mode == "embedding":
+ return self._embedding(inputs)
+ elif mode == "linear":
+ return self._linear(inputs)
+ else:
+ raise ValueError(f"mode {mode} is not valid.")
+
+ def _embedding(self, input_ids):
+ """Applies embedding based on inputs tensor."""
+ return tf.gather(self.weight, input_ids)
+
+ def _linear(self, inputs):
+ """
+ Computes logits by running inputs through a linear layer.
+
+ Args:
+ inputs: A float32 tensor with shape [..., hidden_size]
+
+ Returns:
+ float32 tensor with shape [..., vocab_size].
+ """
+ first_dims = shape_list(inputs)[:-1]
+ x = tf.reshape(inputs, [-1, self.hidden_size])
+ logits = tf.matmul(x, self.weight, transpose_b=True)
+
+ return tf.reshape(logits, first_dims + [self.vocab_size])
+
+
+class TFSequenceSummary(keras.layers.Layer):
+ """
+ Compute a single vector summary of a sequence hidden states.
+
+ Args:
+ config ([`PretrainedConfig`]):
+ The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
+ config class of your model for the default values it uses):
+
+ - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:
+
+ - `"last"` -- Take the last token hidden state (like XLNet)
+ - `"first"` -- Take the first token hidden state (like Bert)
+ - `"mean"` -- Take the mean of all tokens hidden states
+ - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
+ - `"attn"` -- Not implemented now, use multi-head attention
+
+ - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
+ - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
+ (otherwise to `config.hidden_size`).
+ - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
+ another string or `None` will add no activation.
+ - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
+ - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation.
+
+ initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation to use to initialize the weights.
+ kwargs (`dict[str, Any]`, *optional*):
+ Additional keyword arguments passed along to the `__init__` of `keras.layers.Layer`.
+ """
+
+ def __init__(self, config: PretrainedConfig, initializer_range: float = 0.02, **kwargs):
+ super().__init__(**kwargs)
+
+ self.summary_type = config.summary_type if hasattr(config, "summary_use_proj") else "last"
+ if self.summary_type == "attn":
+ # We should use a standard multi-head attention module with absolute positional embedding for that.
+ # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
+ # We can probably just use the multi-head attention module of PyTorch >=1.1.0
+ raise NotImplementedError
+
+ self.has_summary = hasattr(config, "summary_use_proj") and config.summary_use_proj
+ if self.has_summary:
+ if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
+ num_classes = config.num_labels
+ else:
+ num_classes = config.hidden_size
+ self.summary = keras.layers.Dense(
+ num_classes, kernel_initializer=get_initializer(initializer_range), name="summary"
+ )
+
+ self.has_activation = False
+ activation_string = getattr(config, "summary_activation", None)
+ if activation_string is not None:
+ self.has_activation = True
+ self.activation = get_tf_activation(activation_string)
+
+ self.has_first_dropout = hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0
+ if self.has_first_dropout:
+ self.first_dropout = keras.layers.Dropout(config.summary_first_dropout)
+
+ self.has_last_dropout = hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0
+ if self.has_last_dropout:
+ self.last_dropout = keras.layers.Dropout(config.summary_last_dropout)
+ self.hidden_size = config.hidden_size
+
+ def call(self, inputs, cls_index=None, training=False):
+ if not isinstance(inputs, (dict, tuple, list)):
+ hidden_states = inputs
+ elif isinstance(inputs, (tuple, list)):
+ hidden_states = inputs[0]
+ cls_index = inputs[1] if len(inputs) > 1 else None
+ assert len(inputs) <= 2, "Too many inputs."
+ else:
+ hidden_states = inputs.get("hidden_states")
+ cls_index = inputs.get("cls_index", None)
+
+ if self.summary_type == "last":
+ output = hidden_states[:, -1]
+ elif self.summary_type == "first":
+ output = hidden_states[:, 0]
+ elif self.summary_type == "mean":
+ output = tf.reduce_mean(hidden_states, axis=1)
+ elif self.summary_type == "cls_index":
+ hidden_shape = shape_list(hidden_states) # e.g. [batch, num choices, seq length, hidden dims]
+ if cls_index is None:
+ cls_index = tf.fill(
+ hidden_shape[:-2], hidden_shape[-2] - 1
+ ) # A tensor full of shape [batch] or [batch, num choices] full of sequence length
+ cls_shape = shape_list(cls_index)
+ if len(cls_shape) <= len(hidden_shape) - 2:
+ cls_index = tf.expand_dims(cls_index, axis=-1)
+ # else:
+ # cls_index = cls_index[..., tf.newaxis]
+ # cls_index = cls_index.expand((-1,) * (cls_index.dim()-1) + (hidden_states.size(-1),))
+ # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
+ output = tf.gather(hidden_states, cls_index, batch_dims=len(hidden_shape) - 2)
+ output = tf.squeeze(
+ output, axis=len(hidden_shape) - 2
+ ) # shape of output: (batch, num choices, hidden_size)
+ elif self.summary_type == "attn":
+ raise NotImplementedError
+
+ if self.has_first_dropout:
+ output = self.first_dropout(output, training=training)
+
+ if self.has_summary:
+ output = self.summary(output)
+
+ if self.has_activation:
+ output = self.activation(output)
+
+ if self.has_last_dropout:
+ output = self.last_dropout(output, training=training)
+
+ return output
+
+ def build(self, input_shape):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "summary", None) is not None:
+ with tf.name_scope("summary"):
+ self.summary.build(self.hidden_size)
+
+
+def get_initializer(initializer_range: float = 0.02) -> keras.initializers.TruncatedNormal:
+ """
+ Creates a `keras.initializers.TruncatedNormal` with the given range.
+
+ Args:
+ initializer_range (*float*, defaults to 0.02): Standard deviation of the initializer range.
+
+ Returns:
+ `keras.initializers.TruncatedNormal`: The truncated normal initializer.
+ """
+ return keras.initializers.TruncatedNormal(stddev=initializer_range)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/optimization.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/optimization.py
new file mode 100644
index 0000000000000000000000000000000000000000..688d0f8db56f393910dc05bb35b53213b46ced2b
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/optimization.py
@@ -0,0 +1,973 @@
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch optimization for BERT model."""
+
+import math
+import warnings
+from functools import partial
+from typing import Optional, Union
+
+import torch
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
+
+from .trainer_pt_utils import LayerWiseDummyOptimizer, LayerWiseDummyScheduler
+from .trainer_utils import SchedulerType
+from .utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+def _get_constant_lambda(_=None):
+ return 1
+
+
+def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
+ """
+ Create a schedule with a constant learning rate, using the learning rate set in optimizer.
+
+ Args:
+ optimizer ([`~torch.optim.Optimizer`]):
+ The optimizer for which to schedule the learning rate.
+ last_epoch (`int`, *optional*, defaults to -1):
+ The index of the last epoch when resuming training.
+
+ Return:
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+ """
+
+ return LambdaLR(optimizer, _get_constant_lambda, last_epoch=last_epoch)
+
+
+def get_reduce_on_plateau_schedule(optimizer: Optimizer, **kwargs):
+ """
+ Create a schedule with a constant learning rate that decreases when a metric has stopped improving.
+
+ Args:
+ optimizer ([`~torch.optim.Optimizer`]):
+ The optimizer for which to schedule the learning rate.
+ kwargs (`dict`, *optional*):
+ Extra parameters to be passed to the scheduler. See `torch.optim.lr_scheduler.ReduceLROnPlateau`
+ for possible parameters.
+
+ Return:
+ `torch.optim.lr_scheduler.ReduceLROnPlateau` with the appropriate schedule.
+ """
+
+ return ReduceLROnPlateau(optimizer, **kwargs)
+
+
+def _get_constant_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int):
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1.0, num_warmup_steps))
+ return 1.0
+
+
+def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1):
+ """
+ Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
+ increases linearly between 0 and the initial lr set in the optimizer.
+
+ Args:
+ optimizer ([`~torch.optim.Optimizer`]):
+ The optimizer for which to schedule the learning rate.
+ num_warmup_steps (`int`):
+ The number of steps for the warmup phase.
+ last_epoch (`int`, *optional*, defaults to -1):
+ The index of the last epoch when resuming training.
+
+ Return:
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+ """
+
+ lr_lambda = partial(_get_constant_schedule_with_warmup_lr_lambda, num_warmup_steps=num_warmup_steps)
+ return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
+
+
+def _get_linear_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int, num_training_steps: int):
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1, num_warmup_steps))
+ return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))
+
+
+def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
+ """
+ Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
+ a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
+
+ Args:
+ optimizer ([`~torch.optim.Optimizer`]):
+ The optimizer for which to schedule the learning rate.
+ num_warmup_steps (`int`):
+ The number of steps for the warmup phase.
+ num_training_steps (`int`):
+ The total number of training steps.
+ last_epoch (`int`, *optional*, defaults to -1):
+ The index of the last epoch when resuming training.
+
+ Return:
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+ """
+
+ lr_lambda = partial(
+ _get_linear_schedule_with_warmup_lr_lambda,
+ num_warmup_steps=num_warmup_steps,
+ num_training_steps=num_training_steps,
+ )
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+def _get_cosine_schedule_with_warmup_lr_lambda(
+ current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float
+):
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1, num_warmup_steps))
+ progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
+ return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
+
+
+def get_cosine_schedule_with_warmup(
+ optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1
+):
+ """
+ Create a schedule with a learning rate that decreases following the values of the cosine function between the
+ initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
+ initial lr set in the optimizer.
+
+ Args:
+ optimizer ([`~torch.optim.Optimizer`]):
+ The optimizer for which to schedule the learning rate.
+ num_warmup_steps (`int`):
+ The number of steps for the warmup phase.
+ num_training_steps (`int`):
+ The total number of training steps.
+ num_cycles (`float`, *optional*, defaults to 0.5):
+ The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
+ following a half-cosine).
+ last_epoch (`int`, *optional*, defaults to -1):
+ The index of the last epoch when resuming training.
+
+ Return:
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+ """
+
+ lr_lambda = partial(
+ _get_cosine_schedule_with_warmup_lr_lambda,
+ num_warmup_steps=num_warmup_steps,
+ num_training_steps=num_training_steps,
+ num_cycles=num_cycles,
+ )
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+def _get_cosine_with_hard_restarts_schedule_with_warmup_lr_lambda(
+ current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: int
+):
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1, num_warmup_steps))
+ progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
+ if progress >= 1.0:
+ return 0.0
+ return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))))
+
+
+def get_cosine_with_hard_restarts_schedule_with_warmup(
+ optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1
+):
+ """
+ Create a schedule with a learning rate that decreases following the values of the cosine function between the
+ initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases
+ linearly between 0 and the initial lr set in the optimizer.
+
+ Args:
+ optimizer ([`~torch.optim.Optimizer`]):
+ The optimizer for which to schedule the learning rate.
+ num_warmup_steps (`int`):
+ The number of steps for the warmup phase.
+ num_training_steps (`int`):
+ The total number of training steps.
+ num_cycles (`int`, *optional*, defaults to 1):
+ The number of hard restarts to use.
+ last_epoch (`int`, *optional*, defaults to -1):
+ The index of the last epoch when resuming training.
+
+ Return:
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+ """
+
+ lr_lambda = partial(
+ _get_cosine_with_hard_restarts_schedule_with_warmup_lr_lambda,
+ num_warmup_steps=num_warmup_steps,
+ num_training_steps=num_training_steps,
+ num_cycles=num_cycles,
+ )
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+def _get_polynomial_decay_schedule_with_warmup_lr_lambda(
+ current_step: int,
+ *,
+ num_warmup_steps: int,
+ num_training_steps: int,
+ lr_end: float,
+ power: float,
+ lr_init: int,
+):
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1, num_warmup_steps))
+ elif current_step > num_training_steps:
+ return lr_end / lr_init # as LambdaLR multiplies by lr_init
+ else:
+ lr_range = lr_init - lr_end
+ decay_steps = num_training_steps - num_warmup_steps
+ pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
+ decay = lr_range * pct_remaining**power + lr_end
+ return decay / lr_init # as LambdaLR multiplies by lr_init
+
+
+def get_polynomial_decay_schedule_with_warmup(
+ optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1
+):
+ """
+ Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
+ optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the
+ initial lr set in the optimizer.
+
+ Args:
+ optimizer ([`~torch.optim.Optimizer`]):
+ The optimizer for which to schedule the learning rate.
+ num_warmup_steps (`int`):
+ The number of steps for the warmup phase.
+ num_training_steps (`int`):
+ The total number of training steps.
+ lr_end (`float`, *optional*, defaults to 1e-7):
+ The end LR.
+ power (`float`, *optional*, defaults to 1.0):
+ Power factor.
+ last_epoch (`int`, *optional*, defaults to -1):
+ The index of the last epoch when resuming training.
+
+ Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT
+ implementation at
+ https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37
+
+ Return:
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+
+ """
+
+ lr_init = optimizer.defaults["lr"]
+ if not (lr_init > lr_end):
+ raise ValueError(f"lr_end ({lr_end}) must be smaller than initial lr ({lr_init})")
+
+ lr_lambda = partial(
+ _get_polynomial_decay_schedule_with_warmup_lr_lambda,
+ num_warmup_steps=num_warmup_steps,
+ num_training_steps=num_training_steps,
+ lr_end=lr_end,
+ power=power,
+ lr_init=lr_init,
+ )
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+def _get_inverse_sqrt_schedule_lr_lambda(current_step: int, *, num_warmup_steps: int, timescale: Optional[int] = None):
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1, num_warmup_steps))
+ shift = timescale - num_warmup_steps
+ decay = 1.0 / math.sqrt((current_step + shift) / timescale)
+ return decay
+
+
+def get_inverse_sqrt_schedule(
+ optimizer: Optimizer, num_warmup_steps: int, timescale: Optional[int] = None, last_epoch: int = -1
+):
+ """
+ Create a schedule with an inverse square-root learning rate, from the initial lr set in the optimizer, after a
+ warmup period which increases lr linearly from 0 to the initial lr set in the optimizer.
+
+ Args:
+ optimizer ([`~torch.optim.Optimizer`]):
+ The optimizer for which to schedule the learning rate.
+ num_warmup_steps (`int`):
+ The number of steps for the warmup phase.
+ timescale (`int`, *optional*, defaults to `num_warmup_steps`):
+ Time scale.
+ last_epoch (`int`, *optional*, defaults to -1):
+ The index of the last epoch when resuming training.
+
+ Return:
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+ """
+ # Note: this implementation is adapted from
+ # https://github.com/google-research/big_vision/blob/f071ce68852d56099437004fd70057597a95f6ef/big_vision/utils.py#L930
+
+ if timescale is None:
+ timescale = num_warmup_steps or 10_000
+
+ lr_lambda = partial(_get_inverse_sqrt_schedule_lr_lambda, num_warmup_steps=num_warmup_steps, timescale=timescale)
+ return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
+
+
+def _get_cosine_schedule_with_warmup_lr_lambda(
+ current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float, min_lr_rate: float = 0.0
+):
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1, num_warmup_steps))
+ progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
+ factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
+ factor = factor * (1 - min_lr_rate) + min_lr_rate
+ return max(0, factor)
+
+
+def get_cosine_with_min_lr_schedule_with_warmup(
+ optimizer: Optimizer,
+ num_warmup_steps: int,
+ num_training_steps: int,
+ num_cycles: float = 0.5,
+ last_epoch: int = -1,
+ min_lr: Optional[float] = None,
+ min_lr_rate: Optional[float] = None,
+):
+ """
+ Create a schedule with a learning rate that decreases following the values of the cosine function between the
+ initial lr set in the optimizer to min_lr, after a warmup period during which it increases linearly between 0 and the
+ initial lr set in the optimizer.
+
+ Args:
+ optimizer ([`~torch.optim.Optimizer`]):
+ The optimizer for which to schedule the learning rate.
+ num_warmup_steps (`int`):
+ The number of steps for the warmup phase.
+ num_training_steps (`int`):
+ The total number of training steps.
+ num_cycles (`float`, *optional*, defaults to 0.5):
+ The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
+ following a half-cosine).
+ last_epoch (`int`, *optional*, defaults to -1):
+ The index of the last epoch when resuming training.
+ min_lr (`float`, *optional*):
+ The minimum learning rate to reach after the cosine schedule.
+ min_lr_rate (`float`, *optional*):
+ The minimum learning rate as a ratio of the initial learning rate. If set, `min_lr` should not be set.
+
+ Return:
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+ """
+
+ if min_lr is not None and min_lr_rate is not None:
+ raise ValueError("Only one of min_lr or min_lr_rate should be set")
+ elif min_lr is not None:
+ min_lr_rate = min_lr / optimizer.defaults["lr"]
+ elif min_lr_rate is None:
+ raise ValueError("One of min_lr or min_lr_rate should be set through the `lr_scheduler_kwargs`")
+
+ lr_lambda = partial(
+ _get_cosine_schedule_with_warmup_lr_lambda,
+ num_warmup_steps=num_warmup_steps,
+ num_training_steps=num_training_steps,
+ num_cycles=num_cycles,
+ min_lr_rate=min_lr_rate,
+ )
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+def _get_cosine_with_min_lr_schedule_with_warmup_lr_rate_lambda(
+ current_step: int,
+ *,
+ num_warmup_steps: int,
+ num_training_steps: int,
+ num_cycles: float,
+ min_lr_rate: float = 0.0,
+ warmup_lr_rate: Optional[float] = None,
+):
+ current_step = float(current_step)
+ num_warmup_steps = float(num_warmup_steps)
+ num_training_steps = float(num_training_steps)
+
+ if current_step < num_warmup_steps:
+ if warmup_lr_rate is None:
+ return (current_step + 1.0) / max(1.0, num_warmup_steps)
+ else:
+ warmup_lr_rate = float(warmup_lr_rate)
+ return warmup_lr_rate + (1.0 - warmup_lr_rate) * (current_step) / (max(1, num_warmup_steps - 1))
+ progress = (current_step - num_warmup_steps + 1.0) / (max(1.0, num_training_steps - num_warmup_steps))
+ factor = 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress))
+ factor = factor * (1 - min_lr_rate) + min_lr_rate
+ return max(0, factor)
+
+
+def get_cosine_with_min_lr_schedule_with_warmup_lr_rate(
+ optimizer: Optimizer,
+ num_warmup_steps: int,
+ num_training_steps: int,
+ num_cycles: float = 0.5,
+ last_epoch: int = -1,
+ min_lr: Optional[float] = None,
+ min_lr_rate: Optional[float] = None,
+ warmup_lr_rate: Optional[float] = None,
+):
+ """
+ Create a schedule with a learning rate that decreases following the values of the cosine function between the
+ initial lr set in the optimizer to min_lr, after a warmup period during which it increases linearly between 0 and the
+ initial lr set in the optimizer.
+
+ Args:
+ optimizer ([`~torch.optim.Optimizer`]):
+ The optimizer for which to schedule the learning rate.
+ num_warmup_steps (`int`):
+ The number of steps for the warmup phase.
+ num_training_steps (`int`):
+ The total number of training steps.
+ num_cycles (`float`, *optional*, defaults to 0.5):
+ The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
+ following a half-cosine).
+ last_epoch (`int`, *optional*, defaults to -1):
+ The index of the last epoch when resuming training.
+ min_lr (`float`, *optional*):
+ The minimum learning rate to reach after the cosine schedule.
+ min_lr_rate (`float`, *optional*):
+ The minimum learning rate as a ratio of the initial learning rate. If set, `min_lr` should not be set.
+ warmup_lr_rate (`float`, *optional*):
+ The minimum learning rate as a ratio of the start learning rate. If not set, `warmup_lr_rate` will be treated as float(1/num_warmup_steps).
+
+ Return:
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+ """
+
+ if min_lr is not None and min_lr_rate is not None:
+ raise ValueError("Only one of min_lr or min_lr_rate should be set")
+ elif min_lr is not None:
+ min_lr_rate = min_lr / optimizer.defaults["lr"]
+ elif min_lr_rate is None:
+ raise ValueError("One of min_lr or min_lr_rate should be set through the `lr_scheduler_kwargs`")
+
+ lr_lambda = partial(
+ _get_cosine_with_min_lr_schedule_with_warmup_lr_rate_lambda,
+ num_warmup_steps=num_warmup_steps,
+ num_training_steps=num_training_steps,
+ num_cycles=num_cycles,
+ min_lr_rate=min_lr_rate,
+ warmup_lr_rate=warmup_lr_rate,
+ )
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+def _get_wsd_scheduler_lambda(
+ current_step: int,
+ *,
+ num_warmup_steps: int,
+ num_stable_steps: int,
+ num_decay_steps: int,
+ warmup_type: str,
+ decay_type: str,
+ min_lr_ratio: float,
+ num_cycles: float,
+):
+ if current_step < num_warmup_steps:
+ progress = float(current_step) / float(max(1, num_warmup_steps))
+ if warmup_type == "linear":
+ factor = progress
+ elif warmup_type == "cosine":
+ factor = 0.5 * (1.0 - math.cos(math.pi * progress))
+ elif warmup_type == "1-sqrt":
+ factor = 1.0 - math.sqrt(1.0 - progress)
+ factor = factor * (1.0 - min_lr_ratio) + min_lr_ratio
+ return max(0.0, factor)
+
+ if current_step < num_warmup_steps + num_stable_steps:
+ return 1.0
+
+ if current_step < num_warmup_steps + num_stable_steps + num_decay_steps:
+ progress = float(current_step - num_warmup_steps - num_stable_steps) / float(max(1, num_decay_steps))
+ if decay_type == "linear":
+ factor = 1.0 - progress
+ elif decay_type == "cosine":
+ factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
+ elif decay_type == "1-sqrt":
+ factor = 1.0 - math.sqrt(progress)
+ factor = factor * (1.0 - min_lr_ratio) + min_lr_ratio
+ return max(0.0, factor)
+ return min_lr_ratio
+
+
+def get_wsd_schedule(
+ optimizer: Optimizer,
+ num_warmup_steps: int,
+ num_decay_steps: int,
+ num_training_steps: Optional[int] = None,
+ num_stable_steps: Optional[int] = None,
+ warmup_type: str = "linear",
+ decay_type: str = "cosine",
+ min_lr_ratio: float = 0,
+ num_cycles: float = 0.5,
+ last_epoch: int = -1,
+):
+ """
+ Create a schedule with a learning rate that has three stages:
+ 1. warmup: increase from min_lr_ratio times the initial learning rate to the initial learning rate following a warmup_type.
+ 2. stable: constant learning rate.
+ 3. decay: decrease from the initial learning rate to min_lr_ratio times the initial learning rate following a decay_type.
+
+ Args:
+ optimizer ([`~torch.optim.Optimizer`]):
+ The optimizer for which to schedule the learning rate.
+ num_warmup_steps (`int`):
+ The number of steps for the warmup phase.
+ num_decay_steps (`int`):
+ The number of steps for the decay phase.
+ num_training_steps (`int`, *optional*):
+ The total number of training steps. This is the sum of the warmup, stable and decay steps. If `num_stable_steps` is not provided, the stable phase will be `num_training_steps - num_warmup_steps - num_decay_steps`.
+ num_stable_steps (`int`, *optional*):
+ The number of steps for the stable phase. Please ensure that `num_warmup_steps + num_stable_steps + num_decay_steps` equals `num_training_steps`, otherwise the other steps will default to the minimum learning rate.
+ warmup_type (`str`, *optional*, defaults to "linear"):
+ The type of warmup to use. Can be 'linear', 'cosine' or '1-sqrt'.
+ decay_type (`str`, *optional*, defaults to "cosine"):
+ The type of decay to use. Can be 'linear', 'cosine' or '1-sqrt'.
+ min_lr_ratio (`float`, *optional*, defaults to 0):
+ The minimum learning rate as a ratio of the initial learning rate.
+ num_cycles (`float`, *optional*, defaults to 0.5):
+ The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
+ following a half-cosine).
+ last_epoch (`int`, *optional*, defaults to -1):
+ The index of the last epoch when resuming training.
+
+ Return:
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+ """
+
+ if num_training_steps is None and num_stable_steps is None:
+ raise ValueError("Either num_training_steps or num_stable_steps must be specified.")
+
+ if num_training_steps is not None and num_stable_steps is not None:
+ warnings.warn("Both num_training_steps and num_stable_steps are specified. num_stable_steps will be used.")
+
+ if warmup_type not in ["linear", "cosine", "1-sqrt"]:
+ raise ValueError(f"Unknown warmup type: {warmup_type}, expected 'linear', 'cosine' or '1-sqrt'")
+
+ if decay_type not in ["linear", "cosine", "1-sqrt"]:
+ raise ValueError(f"Unknown decay type: {decay_type}, expected 'linear', 'cosine' or '1-sqrt'")
+
+ if num_stable_steps is None:
+ num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps
+
+ lr_lambda = partial(
+ _get_wsd_scheduler_lambda,
+ num_warmup_steps=num_warmup_steps,
+ num_stable_steps=num_stable_steps,
+ num_decay_steps=num_decay_steps,
+ warmup_type=warmup_type,
+ decay_type=decay_type,
+ min_lr_ratio=min_lr_ratio,
+ num_cycles=num_cycles,
+ )
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+TYPE_TO_SCHEDULER_FUNCTION = {
+ SchedulerType.LINEAR: get_linear_schedule_with_warmup,
+ SchedulerType.COSINE: get_cosine_schedule_with_warmup,
+ SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup,
+ SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
+ SchedulerType.CONSTANT: get_constant_schedule,
+ SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
+ SchedulerType.INVERSE_SQRT: get_inverse_sqrt_schedule,
+ SchedulerType.REDUCE_ON_PLATEAU: get_reduce_on_plateau_schedule,
+ SchedulerType.COSINE_WITH_MIN_LR: get_cosine_with_min_lr_schedule_with_warmup,
+ SchedulerType.COSINE_WARMUP_WITH_MIN_LR: get_cosine_with_min_lr_schedule_with_warmup_lr_rate,
+ SchedulerType.WARMUP_STABLE_DECAY: get_wsd_schedule,
+}
+
+
+def get_scheduler(
+ name: Union[str, SchedulerType],
+ optimizer: Optimizer,
+ num_warmup_steps: Optional[int] = None,
+ num_training_steps: Optional[int] = None,
+ scheduler_specific_kwargs: Optional[dict] = None,
+):
+ """
+ Unified API to get any scheduler from its name.
+
+ Args:
+ name (`str` or `SchedulerType`):
+ The name of the scheduler to use.
+ optimizer (`torch.optim.Optimizer`):
+ The optimizer that will be used during training.
+ num_warmup_steps (`int`, *optional*):
+ The number of warmup steps to do. This is not required by all schedulers (hence the argument being
+ optional), the function will raise an error if it's unset and the scheduler type requires it.
+ num_training_steps (`int``, *optional*):
+ The number of training steps to do. This is not required by all schedulers (hence the argument being
+ optional), the function will raise an error if it's unset and the scheduler type requires it.
+ scheduler_specific_kwargs (`dict`, *optional*):
+ Extra parameters for schedulers such as cosine with restarts. Mismatched scheduler types and scheduler
+ parameters will cause the scheduler function to raise a TypeError.
+ """
+ name = SchedulerType(name)
+ schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
+
+ # If a `LayerWiseDummyOptimizer` is passed we extract the optimizer dict and
+ # recursively call `get_scheduler` to get the proper schedulers on each parameter
+ if optimizer is not None and isinstance(optimizer, LayerWiseDummyOptimizer):
+ optimizer_dict = optimizer.optimizer_dict
+ scheduler_dict = {}
+
+ for param in optimizer_dict:
+ scheduler_dict[param] = get_scheduler(
+ name,
+ optimizer=optimizer_dict[param],
+ num_warmup_steps=num_warmup_steps,
+ num_training_steps=num_training_steps,
+ scheduler_specific_kwargs=scheduler_specific_kwargs,
+ )
+
+ def scheduler_hook(param):
+ # Since the optimizer hook has been already attached we only need to
+ # attach the scheduler hook, the gradients have been zeroed here
+ scheduler_dict[param].step()
+
+ for param in optimizer_dict:
+ if param.requires_grad:
+ param.register_post_accumulate_grad_hook(scheduler_hook)
+
+ return LayerWiseDummyScheduler(optimizer_dict=optimizer_dict, lr=optimizer.defaults["lr"])
+
+ if name == SchedulerType.CONSTANT:
+ return schedule_func(optimizer)
+
+ if scheduler_specific_kwargs is None:
+ scheduler_specific_kwargs = {}
+
+ if name == SchedulerType.REDUCE_ON_PLATEAU:
+ return schedule_func(optimizer, **scheduler_specific_kwargs)
+
+ # All other schedulers require `num_warmup_steps`
+ if num_warmup_steps is None:
+ raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
+
+ if name == SchedulerType.CONSTANT_WITH_WARMUP:
+ return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
+
+ if name == SchedulerType.INVERSE_SQRT:
+ return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
+
+ # wsd scheduler requires either num_training_steps or num_stable_steps
+ if name == SchedulerType.WARMUP_STABLE_DECAY:
+ return schedule_func(
+ optimizer,
+ num_warmup_steps=num_warmup_steps,
+ num_training_steps=num_training_steps,
+ **scheduler_specific_kwargs,
+ )
+
+ # All other schedulers require `num_training_steps`
+ if num_training_steps is None:
+ raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
+
+ return schedule_func(
+ optimizer,
+ num_warmup_steps=num_warmup_steps,
+ num_training_steps=num_training_steps,
+ **scheduler_specific_kwargs,
+ )
+
+
+class Adafactor(Optimizer):
+ """
+ AdaFactor pytorch implementation can be used as a drop in replacement for Adam original fairseq code:
+ https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py
+
+ Paper: *Adafactor: Adaptive Learning Rates with Sublinear Memory Cost* https://huggingface.co/papers/1804.04235 Note that
+ this optimizer internally adjusts the learning rate depending on the `scale_parameter`, `relative_step` and
+ `warmup_init` options. To use a manual (external) learning rate schedule you should set `scale_parameter=False` and
+ `relative_step=False`.
+
+ Arguments:
+ params (`Iterable[nn.parameter.Parameter]`):
+ Iterable of parameters to optimize or dictionaries defining parameter groups.
+ lr (`float`, *optional*):
+ The external learning rate.
+ eps (`tuple[float, float]`, *optional*, defaults to `(1e-30, 0.001)`):
+ Regularization constants for square gradient and parameter scale respectively
+ clip_threshold (`float`, *optional*, defaults to 1.0):
+ Threshold of root mean square of final gradient update
+ decay_rate (`float`, *optional*, defaults to -0.8):
+ Coefficient used to compute running averages of square
+ beta1 (`float`, *optional*):
+ Coefficient used for computing running averages of gradient
+ weight_decay (`float`, *optional*, defaults to 0.0):
+ Weight decay (L2 penalty)
+ scale_parameter (`bool`, *optional*, defaults to `True`):
+ If True, learning rate is scaled by root mean square
+ relative_step (`bool`, *optional*, defaults to `True`):
+ If True, time-dependent learning rate is computed instead of external learning rate
+ warmup_init (`bool`, *optional*, defaults to `False`):
+ Time-dependent learning rate computation depends on whether warm-up initialization is being used
+
+ This implementation handles low-precision (FP16, bfloat) values, but we have not thoroughly tested.
+
+ Recommended T5 finetuning settings (https://discuss.huggingface.co/t/t5-finetuning-tips/684/3):
+
+ - Training without LR warmup or clip_threshold is not recommended.
+
+ - use scheduled LR warm-up to fixed LR
+ - use clip_threshold=1.0 (https://huggingface.co/papers/1804.04235)
+ - Disable relative updates
+ - Use scale_parameter=False
+ - Additional optimizer operations like gradient clipping should not be used alongside Adafactor
+
+ Example:
+
+ ```python
+ Adafactor(model.parameters(), scale_parameter=False, relative_step=False, warmup_init=False, lr=1e-3)
+ ```
+
+ Others reported the following combination to work well:
+
+ ```python
+ Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
+ ```
+
+ When using `lr=None` with [`Trainer`] you will most likely need to use [`~optimization.AdafactorSchedule`]
+ scheduler as following:
+
+ ```python
+ from transformers.optimization import Adafactor, AdafactorSchedule
+
+ optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
+ lr_scheduler = AdafactorSchedule(optimizer)
+ trainer = Trainer(..., optimizers=(optimizer, lr_scheduler))
+ ```
+
+ Usage:
+
+ ```python
+ # replace AdamW with Adafactor
+ optimizer = Adafactor(
+ model.parameters(),
+ lr=1e-3,
+ eps=(1e-30, 1e-3),
+ clip_threshold=1.0,
+ decay_rate=-0.8,
+ beta1=None,
+ weight_decay=0.0,
+ relative_step=False,
+ scale_parameter=False,
+ warmup_init=False,
+ )
+ ```"""
+
+ def __init__(
+ self,
+ params,
+ lr=None,
+ eps=(1e-30, 1e-3),
+ clip_threshold=1.0,
+ decay_rate=-0.8,
+ beta1=None,
+ weight_decay=0.0,
+ scale_parameter=True,
+ relative_step=True,
+ warmup_init=False,
+ ):
+ if lr is not None and relative_step:
+ raise ValueError("Cannot combine manual `lr` and `relative_step=True` options")
+ if warmup_init and not relative_step:
+ raise ValueError("`warmup_init=True` requires `relative_step=True`")
+
+ defaults = {
+ "lr": lr,
+ "eps": eps,
+ "clip_threshold": clip_threshold,
+ "decay_rate": decay_rate,
+ "beta1": beta1,
+ "weight_decay": weight_decay,
+ "scale_parameter": scale_parameter,
+ "relative_step": relative_step,
+ "warmup_init": warmup_init,
+ }
+ super().__init__(params, defaults)
+
+ @staticmethod
+ def _get_lr(param_group, param_state):
+ rel_step_sz = param_group["lr"]
+ if param_group["relative_step"]:
+ min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2
+ rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"]))
+ param_scale = 1.0
+ if param_group["scale_parameter"]:
+ param_scale = max(param_group["eps"][1], param_state["RMS"])
+ return param_scale * rel_step_sz
+
+ @staticmethod
+ def _get_options(param_group, param_shape):
+ factored = len(param_shape) >= 2
+ use_first_moment = param_group["beta1"] is not None
+ return factored, use_first_moment
+
+ @staticmethod
+ def _rms(tensor):
+ return tensor.norm(2) / (tensor.numel() ** 0.5)
+
+ @staticmethod
+ def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
+ # copy from fairseq's adafactor implementation:
+ # https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505
+ r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
+ c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
+ return torch.mul(r_factor, c_factor)
+
+ @torch.no_grad()
+ def step(self, closure=None):
+ """
+ Performs a single optimization step
+
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ loss = closure()
+
+ for group in self.param_groups:
+ for p in group["params"]:
+ if p.grad is None:
+ continue
+ grad = p.grad
+ if grad.dtype in {torch.float16, torch.bfloat16}:
+ grad = grad.float()
+ if grad.is_sparse:
+ raise RuntimeError("Adafactor does not support sparse gradients.")
+
+ state = self.state[p]
+ grad_shape = grad.shape
+
+ factored, use_first_moment = self._get_options(group, grad_shape)
+ # State Initialization
+ if len(state) == 0:
+ state["step"] = 0
+
+ if use_first_moment:
+ # Exponential moving average of gradient values
+ state["exp_avg"] = torch.zeros_like(grad)
+ if factored:
+ state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
+ state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
+ else:
+ state["exp_avg_sq"] = torch.zeros_like(grad)
+
+ state["RMS"] = 0
+ else:
+ if use_first_moment:
+ state["exp_avg"] = state["exp_avg"].to(grad)
+ if factored:
+ state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
+ state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
+ else:
+ state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
+
+ p_data_fp32 = p
+ if p.dtype in {torch.float16, torch.bfloat16}:
+ p_data_fp32 = p_data_fp32.float()
+
+ state["step"] += 1
+ state["RMS"] = self._rms(p_data_fp32)
+ lr = self._get_lr(group, state)
+
+ beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
+ update = (grad**2) + group["eps"][0]
+ if factored:
+ exp_avg_sq_row = state["exp_avg_sq_row"]
+ exp_avg_sq_col = state["exp_avg_sq_col"]
+
+ exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
+ exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))
+
+ # Approximation of exponential moving average of square of gradient
+ update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
+ update.mul_(grad)
+ else:
+ exp_avg_sq = state["exp_avg_sq"]
+
+ exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
+ update = exp_avg_sq.rsqrt().mul_(grad)
+
+ update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
+ update.mul_(lr)
+
+ if use_first_moment:
+ exp_avg = state["exp_avg"]
+ exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
+ update = exp_avg
+
+ if group["weight_decay"] != 0:
+ p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))
+
+ p_data_fp32.add_(-update)
+
+ if p.dtype in {torch.float16, torch.bfloat16}:
+ p.copy_(p_data_fp32)
+
+ return loss
+
+
+class AdafactorSchedule(LambdaLR):
+ """
+ Since [`~optimization.Adafactor`] performs its own scheduling, if the training loop relies on a scheduler (e.g.,
+ for logging), this class creates a proxy object that retrieves the current lr values from the optimizer.
+
+ It returns `initial_lr` during startup and the actual `lr` during stepping.
+ """
+
+ def __init__(self, optimizer, initial_lr=0.0):
+ def lr_lambda(_):
+ return initial_lr
+
+ for group in optimizer.param_groups:
+ group["initial_lr"] = initial_lr
+ super().__init__(optimizer, lr_lambda)
+ for group in optimizer.param_groups:
+ del group["initial_lr"]
+
+ def get_lr(self):
+ opt = self.optimizer
+ lrs = [
+ opt._get_lr(group, opt.state[group["params"][0]])
+ for group in opt.param_groups
+ if group["params"][0].grad is not None
+ ]
+ if len(lrs) == 0:
+ lrs = self.base_lrs # if called before stepping
+ return lrs
+
+
+def get_adafactor_schedule(optimizer, initial_lr=0.0):
+ """
+ Get a proxy schedule for [`~optimization.Adafactor`]
+
+ Args:
+ optimizer ([`~torch.optim.Optimizer`]):
+ The optimizer for which to schedule the learning rate.
+ initial_lr (`float`, *optional*, defaults to 0.0):
+ Initial lr
+
+ Return:
+ [`~optimization.Adafactor`] proxy schedule object.
+
+
+ """
+ return AdafactorSchedule(optimizer, initial_lr)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/optimization_tf.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/optimization_tf.py
new file mode 100644
index 0000000000000000000000000000000000000000..71a77251f2bf9431a08295b8daabbcbe576de71b
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/optimization_tf.py
@@ -0,0 +1,378 @@
+# Copyright 2019 The TensorFlow Authors, The Hugging Face Team. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functions and classes related to optimization (weight updates)."""
+
+from typing import Callable, Optional, Union
+
+import tensorflow as tf
+
+
+try:
+ from tf_keras.optimizers.legacy import Adam
+except (ImportError, ModuleNotFoundError):
+ from tensorflow.keras.optimizers.legacy import Adam
+
+from .modeling_tf_utils import keras
+
+
+# This block because Keras loves randomly moving things to different places - this changed somewhere between 2.10 - 2.15
+if hasattr(keras.optimizers.schedules, "learning_rate_schedule"):
+ schedules = keras.optimizers.schedules.learning_rate_schedule
+else:
+ schedules = keras.optimizers.schedules
+
+
+class WarmUp(schedules.LearningRateSchedule):
+ """
+ Applies a warmup schedule on a given learning rate decay schedule.
+
+ Args:
+ initial_learning_rate (`float`):
+ The initial learning rate for the schedule after the warmup (so this will be the learning rate at the end
+ of the warmup).
+ decay_schedule_fn (`Callable`):
+ The schedule function to apply after the warmup for the rest of training.
+ warmup_steps (`int`):
+ The number of steps for the warmup part of training.
+ power (`float`, *optional*, defaults to 1.0):
+ The power to use for the polynomial warmup (defaults is a linear warmup).
+ name (`str`, *optional*):
+ Optional name prefix for the returned tensors during the schedule.
+ """
+
+ def __init__(
+ self,
+ initial_learning_rate: float,
+ decay_schedule_fn: Callable,
+ warmup_steps: int,
+ power: float = 1.0,
+ name: Optional[str] = None,
+ ):
+ super().__init__()
+ self.initial_learning_rate = initial_learning_rate
+ self.warmup_steps = warmup_steps
+ self.power = power
+ self.decay_schedule_fn = decay_schedule_fn
+ self.name = name
+
+ def __call__(self, step):
+ with tf.name_scope(self.name or "WarmUp") as name:
+ # Implements polynomial warmup. i.e., if global_step < warmup_steps, the
+ # learning rate will be `global_step/num_warmup_steps * init_lr`.
+ global_step_float = tf.cast(step, tf.float32)
+ warmup_steps_float = tf.cast(self.warmup_steps, tf.float32)
+ warmup_percent_done = global_step_float / warmup_steps_float
+ warmup_learning_rate = self.initial_learning_rate * tf.math.pow(warmup_percent_done, self.power)
+ return tf.cond(
+ global_step_float < warmup_steps_float,
+ lambda: warmup_learning_rate,
+ lambda: self.decay_schedule_fn(step - self.warmup_steps),
+ name=name,
+ )
+
+ def get_config(self):
+ return {
+ "initial_learning_rate": self.initial_learning_rate,
+ "decay_schedule_fn": self.decay_schedule_fn,
+ "warmup_steps": self.warmup_steps,
+ "power": self.power,
+ "name": self.name,
+ }
+
+
+def create_optimizer(
+ init_lr: float,
+ num_train_steps: int,
+ num_warmup_steps: int,
+ min_lr_ratio: float = 0.0,
+ adam_beta1: float = 0.9,
+ adam_beta2: float = 0.999,
+ adam_epsilon: float = 1e-8,
+ adam_clipnorm: Optional[float] = None,
+ adam_global_clipnorm: Optional[float] = None,
+ weight_decay_rate: float = 0.0,
+ power: float = 1.0,
+ include_in_weight_decay: Optional[list[str]] = None,
+):
+ """
+ Creates an optimizer with a learning rate schedule using a warmup phase followed by a linear decay.
+
+ Args:
+ init_lr (`float`):
+ The desired learning rate at the end of the warmup phase.
+ num_train_steps (`int`):
+ The total number of training steps.
+ num_warmup_steps (`int`):
+ The number of warmup steps.
+ min_lr_ratio (`float`, *optional*, defaults to 0):
+ The final learning rate at the end of the linear decay will be `init_lr * min_lr_ratio`.
+ adam_beta1 (`float`, *optional*, defaults to 0.9):
+ The beta1 to use in Adam.
+ adam_beta2 (`float`, *optional*, defaults to 0.999):
+ The beta2 to use in Adam.
+ adam_epsilon (`float`, *optional*, defaults to 1e-8):
+ The epsilon to use in Adam.
+ adam_clipnorm (`float`, *optional*, defaults to `None`):
+ If not `None`, clip the gradient norm for each weight tensor to this value.
+ adam_global_clipnorm (`float`, *optional*, defaults to `None`)
+ If not `None`, clip gradient norm to this value. When using this argument, the norm is computed over all
+ weight tensors, as if they were concatenated into a single vector.
+ weight_decay_rate (`float`, *optional*, defaults to 0):
+ The weight decay to use.
+ power (`float`, *optional*, defaults to 1.0):
+ The power to use for PolynomialDecay.
+ include_in_weight_decay (`list[str]`, *optional*):
+ List of the parameter names (or re patterns) to apply weight decay to. If none is passed, weight decay is
+ applied to all parameters except bias and layer norm parameters.
+ """
+ # Implements linear decay of the learning rate.
+ lr_schedule = schedules.PolynomialDecay(
+ initial_learning_rate=init_lr,
+ decay_steps=num_train_steps - num_warmup_steps,
+ end_learning_rate=init_lr * min_lr_ratio,
+ power=power,
+ )
+ if num_warmup_steps:
+ lr_schedule = WarmUp(
+ initial_learning_rate=init_lr,
+ decay_schedule_fn=lr_schedule,
+ warmup_steps=num_warmup_steps,
+ )
+ if weight_decay_rate > 0.0:
+ optimizer = AdamWeightDecay(
+ learning_rate=lr_schedule,
+ weight_decay_rate=weight_decay_rate,
+ beta_1=adam_beta1,
+ beta_2=adam_beta2,
+ epsilon=adam_epsilon,
+ clipnorm=adam_clipnorm,
+ global_clipnorm=adam_global_clipnorm,
+ exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
+ include_in_weight_decay=include_in_weight_decay,
+ )
+ else:
+ optimizer = keras.optimizers.Adam(
+ learning_rate=lr_schedule,
+ beta_1=adam_beta1,
+ beta_2=adam_beta2,
+ epsilon=adam_epsilon,
+ clipnorm=adam_clipnorm,
+ global_clipnorm=adam_global_clipnorm,
+ )
+ # We return the optimizer and the LR scheduler in order to better track the
+ # evolution of the LR independently of the optimizer.
+ return optimizer, lr_schedule
+
+
+class AdamWeightDecay(Adam):
+ """
+ Adam enables L2 weight decay and clip_by_global_norm on gradients. Just adding the square of the weights to the
+ loss function is *not* the correct way of using L2 regularization/weight decay with Adam, since that will interact
+ with the m and v parameters in strange ways as shown in [Decoupled Weight Decay
+ Regularization](https://huggingface.co/papers/1711.05101).
+
+ Instead we want to decay the weights in a manner that doesn't interact with the m/v parameters. This is equivalent
+ to adding the square of the weights to the loss with plain (non-momentum) SGD.
+
+ Args:
+ learning_rate (`Union[float, LearningRateSchedule]`, *optional*, defaults to 0.001):
+ The learning rate to use or a schedule.
+ beta_1 (`float`, *optional*, defaults to 0.9):
+ The beta1 parameter in Adam, which is the exponential decay rate for the 1st momentum estimates.
+ beta_2 (`float`, *optional*, defaults to 0.999):
+ The beta2 parameter in Adam, which is the exponential decay rate for the 2nd momentum estimates.
+ epsilon (`float`, *optional*, defaults to 1e-07):
+ The epsilon parameter in Adam, which is a small constant for numerical stability.
+ amsgrad (`bool`, *optional*, defaults to `False`):
+ Whether to apply AMSGrad variant of this algorithm or not, see [On the Convergence of Adam and
+ Beyond](https://huggingface.co/papers/1904.09237).
+ weight_decay_rate (`float`, *optional*, defaults to 0.0):
+ The weight decay to apply.
+ include_in_weight_decay (`list[str]`, *optional*):
+ List of the parameter names (or re patterns) to apply weight decay to. If none is passed, weight decay is
+ applied to all parameters by default (unless they are in `exclude_from_weight_decay`).
+ exclude_from_weight_decay (`list[str]`, *optional*):
+ List of the parameter names (or re patterns) to exclude from applying weight decay to. If a
+ `include_in_weight_decay` is passed, the names in it will supersede this list.
+ name (`str`, *optional*, defaults to `"AdamWeightDecay"`):
+ Optional name for the operations created when applying gradients.
+ kwargs (`dict[str, Any]`, *optional*):
+ Keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by
+ norm; `clipvalue` is clip gradients by value, `decay` is included for backward compatibility to allow time
+ inverse decay of learning rate. `lr` is included for backward compatibility, recommended to use
+ `learning_rate` instead.
+ """
+
+ def __init__(
+ self,
+ learning_rate: Union[float, schedules.LearningRateSchedule] = 0.001,
+ beta_1: float = 0.9,
+ beta_2: float = 0.999,
+ epsilon: float = 1e-7,
+ amsgrad: bool = False,
+ weight_decay_rate: float = 0.0,
+ include_in_weight_decay: Optional[list[str]] = None,
+ exclude_from_weight_decay: Optional[list[str]] = None,
+ name: str = "AdamWeightDecay",
+ **kwargs,
+ ):
+ super().__init__(learning_rate, beta_1, beta_2, epsilon, amsgrad, name, **kwargs)
+ self.weight_decay_rate = weight_decay_rate
+ self._include_in_weight_decay = include_in_weight_decay
+ self._exclude_from_weight_decay = exclude_from_weight_decay
+
+ @classmethod
+ def from_config(cls, config):
+ """Creates an optimizer from its config with WarmUp custom object."""
+ custom_objects = {"WarmUp": WarmUp}
+ return super().from_config(config, custom_objects=custom_objects)
+
+ def _prepare_local(self, var_device, var_dtype, apply_state):
+ super()._prepare_local(var_device, var_dtype, apply_state)
+ apply_state[(var_device, var_dtype)]["weight_decay_rate"] = tf.constant(
+ self.weight_decay_rate, name="adam_weight_decay_rate"
+ )
+
+ def _decay_weights_op(self, var, learning_rate, apply_state):
+ do_decay = self._do_use_weight_decay(var.name)
+ if do_decay:
+ return var.assign_sub(
+ learning_rate * var * apply_state[(var.device, var.dtype.base_dtype)]["weight_decay_rate"],
+ use_locking=self._use_locking,
+ )
+ return tf.no_op()
+
+ def apply_gradients(self, grads_and_vars, name=None, **kwargs):
+ grads, tvars = list(zip(*grads_and_vars))
+ return super().apply_gradients(zip(grads, tvars), name=name, **kwargs)
+
+ def _get_lr(self, var_device, var_dtype, apply_state):
+ """Retrieves the learning rate with the given state."""
+ if apply_state is None:
+ return self._decayed_lr_t[var_dtype], {}
+
+ apply_state = apply_state or {}
+ coefficients = apply_state.get((var_device, var_dtype))
+ if coefficients is None:
+ coefficients = self._fallback_apply_state(var_device, var_dtype)
+ apply_state[(var_device, var_dtype)] = coefficients
+
+ return coefficients["lr_t"], {"apply_state": apply_state}
+
+ def _resource_apply_dense(self, grad, var, apply_state=None):
+ lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
+ decay = self._decay_weights_op(var, lr_t, apply_state)
+ with tf.control_dependencies([decay]):
+ return super()._resource_apply_dense(grad, var, **kwargs)
+
+ def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
+ lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
+ decay = self._decay_weights_op(var, lr_t, apply_state)
+ with tf.control_dependencies([decay]):
+ return super()._resource_apply_sparse(grad, var, indices, **kwargs)
+
+ def get_config(self):
+ config = super().get_config()
+ config.update({"weight_decay_rate": self.weight_decay_rate})
+ return config
+
+ def _do_use_weight_decay(self, param_name):
+ """Whether to use L2 weight decay for `param_name`."""
+ if self.weight_decay_rate == 0:
+ return False
+
+ if self._include_in_weight_decay:
+ for r in self._include_in_weight_decay:
+ if r in param_name:
+ return True
+
+ if self._exclude_from_weight_decay:
+ for r in self._exclude_from_weight_decay:
+ if r in param_name:
+ return False
+ return True
+
+
+# Extracted from https://github.com/OpenNMT/OpenNMT-tf/blob/master/opennmt/optimizers/utils.py
+class GradientAccumulator:
+ """
+ Gradient accumulation utility. When used with a distribution strategy, the accumulator should be called in a
+ replica context. Gradients will be accumulated locally on each replica and without synchronization. Users should
+ then call `.gradients`, scale the gradients if required, and pass the result to `apply_gradients`.
+ """
+
+ # We use the ON_READ synchronization policy so that no synchronization is
+ # performed on assignment. To get the value, we call .value() which returns the
+ # value on the current replica without synchronization.
+
+ def __init__(self):
+ """Initializes the accumulator."""
+ self._gradients = []
+ self._accum_steps = None
+
+ @property
+ def step(self):
+ """Number of accumulated steps."""
+ if self._accum_steps is None:
+ self._accum_steps = tf.Variable(
+ tf.constant(0, dtype=tf.int64),
+ trainable=False,
+ synchronization=tf.VariableSynchronization.ON_READ,
+ aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
+ )
+
+ return self._accum_steps.value()
+
+ @property
+ def gradients(self):
+ """The accumulated gradients on the current replica."""
+ if not self._gradients:
+ raise ValueError("The accumulator should be called first to initialize the gradients")
+ return [gradient.value() if gradient is not None else gradient for gradient in self._gradients]
+
+ def __call__(self, gradients):
+ """Accumulates `gradients` on the current replica."""
+ if not self._gradients:
+ _ = self.step # Create the step variable.
+ self._gradients.extend(
+ [
+ tf.Variable(
+ tf.zeros_like(gradient),
+ trainable=False,
+ synchronization=tf.VariableSynchronization.ON_READ,
+ aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
+ )
+ if gradient is not None
+ else gradient
+ for gradient in gradients
+ ]
+ )
+ if len(gradients) != len(self._gradients):
+ raise ValueError(f"Expected {len(self._gradients)} gradients, but got {len(gradients)}")
+
+ for accum_gradient, gradient in zip(self._gradients, gradients):
+ if accum_gradient is not None and gradient is not None:
+ accum_gradient.assign_add(gradient)
+
+ self._accum_steps.assign_add(1)
+
+ def reset(self):
+ """Resets the accumulated gradients on the current replica."""
+ if not self._gradients:
+ return
+ self._accum_steps.assign(0)
+ for gradient in self._gradients:
+ if gradient is not None:
+ gradient.assign(tf.zeros_like(gradient))
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/processing_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/processing_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b40b6535f1b5448c8f97fa913aab224194f7121
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/processing_utils.py
@@ -0,0 +1,1782 @@
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Processing saving/loading class for common processors.
+"""
+
+import bisect
+import copy
+import inspect
+import json
+import os
+import sys
+import typing
+import warnings
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Optional, TypedDict, TypeVar, Union
+
+import numpy as np
+import typing_extensions
+from huggingface_hub.errors import EntryNotFoundError
+
+from .audio_utils import AudioInput, load_audio
+from .dynamic_module_utils import custom_object_save
+from .feature_extraction_utils import BatchFeature
+from .image_utils import ChannelDimension, ImageInput, is_vision_available
+from .utils.chat_template_utils import render_jinja_template
+from .video_utils import VideoInput, VideoMetadata
+
+
+if is_vision_available():
+ from .image_utils import PILImageResampling
+
+
+from .tokenization_utils_base import (
+ PaddingStrategy,
+ PreTokenizedInput,
+ PreTrainedTokenizerBase,
+ TextInput,
+ TruncationStrategy,
+)
+from .utils import (
+ AUDIO_TOKENIZER_NAME,
+ CHAT_TEMPLATE_DIR,
+ CHAT_TEMPLATE_FILE,
+ LEGACY_PROCESSOR_CHAT_TEMPLATE_FILE,
+ PROCESSOR_NAME,
+ PushToHubMixin,
+ TensorType,
+ cached_file,
+ copy_func,
+ direct_transformers_import,
+ download_url,
+ is_offline_mode,
+ is_remote_url,
+ is_torch_available,
+ list_repo_templates,
+ logging,
+)
+from .utils.deprecation import deprecate_kwarg
+
+
+if is_torch_available():
+ from .modeling_utils import PreTrainedAudioTokenizerBase
+
+
+logger = logging.get_logger(__name__)
+
+# type hinting: specifying the type of processor class that inherits from ProcessorMixin
+SpecificProcessorType = TypeVar("SpecificProcessorType", bound="ProcessorMixin")
+
+# Dynamically import the Transformers module to grab the attribute classes of the processor from their names.
+transformers_module = direct_transformers_import(Path(__file__).parent)
+
+
+AUTO_TO_BASE_CLASS_MAPPING = {
+ "AutoTokenizer": "PreTrainedTokenizerBase",
+ "AutoFeatureExtractor": "FeatureExtractionMixin",
+ "AutoImageProcessor": "ImageProcessingMixin",
+ "AutoVideoProcessor": "BaseVideoProcessor",
+}
+
+if sys.version_info >= (3, 11):
+ Unpack = typing.Unpack
+else:
+ Unpack = typing_extensions.Unpack
+
+
+class TextKwargs(TypedDict, total=False):
+ """
+ Keyword arguments for text processing. For extended documentation, check out tokenization_utils_base methods and
+ docstrings associated.
+
+ Attributes:
+ add_special_tokens (`bool`, *optional*)
+ Whether or not to add special tokens when encoding the sequences.
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*)
+ Activates and controls padding.
+ truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*):
+ Activates and controls truncation.
+ max_length (`int`, *optional*):
+ Controls the maximum length to use by one of the truncation/padding parameters.
+ stride (`int`, *optional*):
+ If set, the overflowing tokens will contain some tokens from the end of the truncated sequence.
+ is_split_into_words (`bool`, *optional*):
+ Whether or not the input is already pre-tokenized.
+ pad_to_multiple_of (`int`, *optional*):
+ If set, will pad the sequence to a multiple of the provided value.
+ return_token_type_ids (`bool`, *optional*):
+ Whether to return token type IDs.
+ return_attention_mask (`bool`, *optional*):
+ Whether to return the attention mask.
+ return_overflowing_tokens (`bool`, *optional*):
+ Whether or not to return overflowing token sequences.
+ return_special_tokens_mask (`bool`, *optional*):
+ Whether or not to return special tokens mask information.
+ return_offsets_mapping (`bool`, *optional*):
+ Whether or not to return `(char_start, char_end)` for each token.
+ return_length (`bool`, *optional*):
+ Whether or not to return the lengths of the encoded inputs.
+ verbose (`bool`, *optional*):
+ Whether or not to print more information and warnings.
+ padding_side (`str`, *optional*):
+ The side on which padding will be applied.
+ return_mm_token_type_ids (`bool`, *optional*):
+ Whether to return multimodal token type ids indicating mm placeholder token positions.
+ """
+
+ text_pair: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]]
+ text_target: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]
+ text_pair_target: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]]
+ add_special_tokens: Optional[bool]
+ padding: Union[bool, str, PaddingStrategy]
+ truncation: Union[bool, str, TruncationStrategy]
+ max_length: Optional[int]
+ stride: Optional[int]
+ is_split_into_words: Optional[bool]
+ pad_to_multiple_of: Optional[int]
+ return_token_type_ids: Optional[bool]
+ return_attention_mask: Optional[bool]
+ return_overflowing_tokens: Optional[bool]
+ return_special_tokens_mask: Optional[bool]
+ return_offsets_mapping: Optional[bool]
+ return_length: Optional[bool]
+ verbose: Optional[bool]
+ padding_side: Optional[str]
+ return_mm_token_type_ids: Optional[bool]
+
+
+class ImagesKwargs(TypedDict, total=False):
+ """
+ Keyword arguments for image processing. For extended documentation, check the appropriate ImageProcessor
+ class methods and docstrings.
+
+ Attributes:
+ do_resize (`bool`, *optional*):
+ Whether to resize the image.
+ size (`dict[str, int]`, *optional*):
+ Resize the shorter side of the input to `size["shortest_edge"]`.
+ crop_size (`dict[str, int]`, *optional*):
+ Desired output size when applying center-cropping.
+ resample (`PILImageResampling`, *optional*):
+ Resampling filter to use if resizing the image.
+ do_rescale (`bool`, *optional*):
+ Whether to rescale the image by the specified scale `rescale_factor`.
+ rescale_factor (`int` or `float`, *optional*):
+ Scale factor to use if rescaling the image.
+ do_normalize (`bool`, *optional*):
+ Whether to normalize the image.
+ image_mean (`float` or `list[float]`, *optional*):
+ Mean to use if normalizing the image.
+ image_std (`float` or `list[float]`, *optional*):
+ Standard deviation to use if normalizing the image.
+ do_pad (`bool`, *optional*):
+ Whether to pad the image to the `(max_height, max_width)` of the images in the batch.
+ pad_size (`dict[str, int]`, *optional*):
+ The size `{"height": int, "width" int}` to pad the images to.
+ do_center_crop (`bool`, *optional*):
+ Whether to center crop the image.
+ data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the output image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image.
+ device (`str`, *optional*):
+ The device to use for processing (e.g. "cpu", "cuda"), only relevant for fast image processing.
+ """
+
+ do_resize: Optional[bool]
+ size: Optional[dict[str, int]]
+ crop_size: Optional[dict[str, int]]
+ resample: Optional[Union["PILImageResampling", int]]
+ do_rescale: Optional[bool]
+ rescale_factor: Optional[float]
+ do_normalize: Optional[bool]
+ image_mean: Optional[Union[float, list[float]]]
+ image_std: Optional[Union[float, list[float]]]
+ do_pad: Optional[bool]
+ pad_size: Optional[dict[str, int]]
+ do_center_crop: Optional[bool]
+ data_format: Optional[ChannelDimension]
+ input_data_format: Optional[Union[str, ChannelDimension]]
+ device: Optional[str]
+
+
+class VideosKwargs(TypedDict, total=False):
+ """
+ Keyword arguments for video processing.
+
+ Attributes:
+ do_convert_rgb (`bool`):
+ Whether to convert the video to RGB format.
+ do_resize (`bool`):
+ Whether to resize the video.
+ size (`dict[str, int]`, *optional*):
+ Resize the shorter side of the input to `size["shortest_edge"]`.
+ default_to_square (`bool`, *optional*, defaults to `self.default_to_square`):
+ Whether to default to a square when resizing, if size is an int.
+ resample (`PILImageResampling`, *optional*):
+ Resampling filter to use if resizing the video.
+ do_rescale (`bool`, *optional*):
+ Whether to rescale the video by the specified scale `rescale_factor`.
+ rescale_factor (`int` or `float`, *optional*):
+ Scale factor to use if rescaling the video.
+ do_normalize (`bool`, *optional*):
+ Whether to normalize the video.
+ image_mean (`float` or `list[float]`, *optional*):
+ Mean to use if normalizing the video.
+ image_std (`float` or `list[float]`, *optional*):
+ Standard deviation to use if normalizing the video.
+ do_center_crop (`bool`, *optional*):
+ Whether to center crop the video.
+ do_sample_frames (`bool`, *optional*):
+ Whether to sample frames from the video before processing or to process the whole video.
+ video_metadata (`Union[VideoMetadata, dict]`, *optional*):
+ Metadata of the video containing information about total duration, fps and total number of frames.
+ num_frames (`int`, *optional*):
+ Maximum number of frames to sample when `do_sample_frames=True`.
+ fps (`int` or `float`, *optional*):
+ Target frames to sample per second when `do_sample_frames=True`.
+ crop_size (`dict[str, int]`, *optional*):
+ Desired output size when applying center-cropping.
+ data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the output video.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input video.
+ return_metadata (`ChannelDimension` or `str`, *optional*):
+ Whether to return video metadata or not.
+ """
+
+ do_convert_rgb: Optional[bool]
+ do_resize: Optional[bool]
+ size: Optional[dict[str, int]]
+ default_to_square: Optional[bool]
+ resample: Optional["PILImageResampling"]
+ do_rescale: Optional[bool]
+ rescale_factor: Optional[float]
+ do_normalize: Optional[bool]
+ image_mean: Optional[Union[float, list[float]]]
+ image_std: Optional[Union[float, list[float]]]
+ do_center_crop: Optional[bool]
+ crop_size: Optional[dict[str, int]]
+ data_format: Optional[ChannelDimension]
+ input_data_format: Optional[Union[str, ChannelDimension]]
+ device: Optional[str]
+ do_sample_frames: Optional[bool]
+ video_metadata: Optional[Union[VideoMetadata, dict]]
+ fps: Optional[Union[int, float]]
+ num_frames: Optional[int]
+ return_metadata: Optional[bool]
+
+
+class AudioKwargs(TypedDict, total=False):
+ """
+ Keyword arguments for audio processing.
+
+ Attributes:
+ sampling_rate (`int`, *optional*):
+ The sampling rate at which the `raw_speech` input was sampled.
+ raw_speech (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`):
+ The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
+ values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not
+ stereo, i.e. single float per timestep.
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*):
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding
+ index) among:
+
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+ sequence if provided).
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+ acceptable input length for the model if that argument is not provided.
+ - `False` or `'do_not_pad'`
+ max_length (`int`, *optional*):
+ Maximum length of the returned list and optionally padding length (see above).
+ truncation (`bool`, *optional*):
+ Activates truncation to cut input sequences longer than *max_length* to *max_length*.
+ pad_to_multiple_of (`int`, *optional*):
+ If set, will pad the sequence to a multiple of the provided value.
+ return_attention_mask (`bool`, *optional*):
+ Whether or not [`~ASTFeatureExtractor.__call__`] should return `attention_mask`.
+ """
+
+ sampling_rate: Optional[int]
+ raw_speech: Optional[Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]]]
+ padding: Optional[Union[bool, str, PaddingStrategy]]
+ max_length: Optional[int]
+ truncation: Optional[bool]
+ pad_to_multiple_of: Optional[int]
+ return_attention_mask: Optional[bool]
+
+
+class CommonKwargs(TypedDict, total=False):
+ return_tensors: Optional[Union[str, TensorType]]
+
+
+class ProcessingKwargs(TypedDict, total=False):
+ """
+ Base class for kwargs passing to processors.
+ In case a model has specific kwargs that are not present in the base class or default values for existing keys,
+ it should have its own `ModelProcessorKwargs` class that inherits from `ProcessingKwargs` to provide:
+ 1) Additional typed keys and that this model requires to process inputs.
+ 2) Default values for existing keys under a `_defaults` attribute.
+ New keys have to be defined as follows to ensure type hinting is done correctly.
+
+ ```python
+ # adding a new image kwarg for this model
+ class ModelImagesKwargs(ImagesKwargs, total=False):
+ new_image_kwarg: Optional[bool]
+
+ class ModelProcessorKwargs(ProcessingKwargs, total=False):
+ images_kwargs: ModelImagesKwargs
+ _defaults = {
+ "images_kwargs: {
+ "new_image_kwarg": False,
+ }
+ "text_kwargs": {
+ "padding": "max_length",
+ },
+ }
+
+ ```
+
+ For Python 3.8 compatibility, when inheriting from this class and overriding one of the kwargs,
+ you need to manually update the __annotations__ dictionary. This can be done as follows:
+
+ ```python
+ class CustomProcessorKwargs(ProcessingKwargs, total=False):
+ images_kwargs: CustomImagesKwargs
+
+ CustomProcessorKwargs.__annotations__["images_kwargs"] = CustomImagesKwargs # python 3.8 compatibility
+ ```python
+
+ """
+
+ _defaults = {}
+
+ common_kwargs: CommonKwargs = {
+ **CommonKwargs.__annotations__,
+ }
+ text_kwargs: TextKwargs = {
+ **TextKwargs.__annotations__,
+ }
+ images_kwargs: ImagesKwargs = {
+ **ImagesKwargs.__annotations__,
+ }
+ videos_kwargs: VideosKwargs = {
+ **VideosKwargs.__annotations__,
+ }
+ audio_kwargs: AudioKwargs = {
+ **AudioKwargs.__annotations__,
+ }
+
+
+class TokenizerChatTemplateKwargs(TypedDict, total=False):
+ """
+ Keyword arguments for tokenizer's `apply_chat_template`, when it is called from within a processor.
+
+ tools (`list[Dict]`, *optional*):
+ A list of tools (callable functions) that will be accessible to the model. If the template does not
+ support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema,
+ giving the name, description and argument types for the tool. See our
+ [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use)
+ for more information.
+ documents (`list[dict[str, str]]`, *optional*):
+ A list of dicts representing documents that will be accessible to the model if it is performing RAG
+ (retrieval-augmented generation). If the template does not support RAG, this argument will have no
+ effect. We recommend that each document should be a dict containing "title" and "text" keys. Please
+ see the RAG section of the [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#arguments-for-RAG)
+ for examples of passing documents with chat templates.
+ add_generation_prompt (bool, *optional*):
+ If this is set, a prompt with the token(s) that indicate
+ the start of an assistant message will be appended to the formatted output. This is useful when you want to generate a response from the model.
+ Note that this argument will be passed to the chat template, and so it must be supported in the
+ template for this argument to have any effect.
+ continue_final_message (bool, *optional*):
+ If this is set, the chat will be formatted so that the final
+ message in the chat is open-ended, without any EOS tokens. The model will continue this message
+ rather than starting a new one. This allows you to "prefill" part of
+ the model's response for it. Cannot be used at the same time as `add_generation_prompt`.
+ return_assistant_tokens_mask (`bool`, defaults to `False`):
+ Whether to return a mask of the assistant generated tokens. For tokens generated by the assistant,
+ the mask will contain 1. For user and system tokens, the mask will contain 0.
+ This functionality is only available for chat templates that support it via the `{% generation %}` keyword.
+ """
+
+ tools: Optional[list[dict]] = None
+ documents: Optional[list[dict[str, str]]] = None
+ add_generation_prompt: Optional[bool] = False
+ continue_final_message: Optional[bool] = False
+ return_assistant_tokens_mask: Optional[bool] = False
+
+
+class ChatTemplateLoadKwargs(TypedDict, total=False):
+ """
+ Keyword arguments used to load multimodal data in processor chat templates.
+
+ num_frames (`int`, *optional*):
+ Number of frames to sample uniformly. If not passed, the whole video is loaded.
+ load_audio_from_video (`bool`, *optional*):
+ Whether to use the audio track of input video. If `True` the audio track will be loaded and passed to the
+ processor. This flag has no effect if the model doesn't support audio modality.
+ """
+
+ sampling_rate: Optional[int] = 16_000
+ load_audio_from_video: Optional[bool] = False
+
+
+class ProcessorChatTemplateKwargs(ChatTemplateLoadKwargs, TokenizerChatTemplateKwargs, total=False):
+ """
+ Keyword arguments for processor's `apply_chat_template`.
+
+ tokenize (`bool`, *optional*, defaults to `False`):
+ Whether to tokenize the output or not.
+ return_dict (`bool`, defaults to `False`):
+ Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`.
+ """
+
+ tokenize: Optional[bool] = False
+ return_dict: Optional[bool] = False
+
+
+class AllKwargsForChatTemplate(TypedDict, total=False):
+ processor_kwargs: ProcessingKwargs
+ mm_load_kwargs: ChatTemplateLoadKwargs
+ template_kwargs: ProcessorChatTemplateKwargs
+
+
+@dataclass
+class MultiModalData:
+ """
+ Dataclass that holds extra useful data for processing
+ multimodal data. Processors currently cannot return keys,
+ unless it is used in model's forward. Thus we have helper
+ methods that calculate and return useful data from processing
+ input multimodals (images/videos).
+ Note that this dataclass is aimed to be used only in vLLM
+ and we might change its API in the future.
+ """
+
+ num_image_tokens: Optional[list[int]] = None
+ num_video_tokens: Optional[list[int]] = None
+ num_audio_tokens: Optional[list[int]] = None
+ num_image_patches: Optional[list[int]] = None
+
+ def __contains__(self, key):
+ return hasattr(self, key) and getattr(self, key) is not None
+
+ def __getitem__(self, key):
+ if hasattr(self, key):
+ return getattr(self, key)
+ raise AttributeError(f"{self.__class__.__name__} has no attribute {key}")
+
+
+class ProcessorMixin(PushToHubMixin):
+ """
+ This is a mixin used to provide saving/loading functionality for all processor classes.
+ """
+
+ attributes = ["feature_extractor", "tokenizer"]
+ optional_attributes = ["chat_template", "audio_tokenizer"]
+ optional_call_args: list[str] = []
+ # Names need to be attr_class for attr in attributes
+ feature_extractor_class = None
+ tokenizer_class = None
+ _auto_class = None
+ valid_processor_kwargs = ProcessingKwargs
+
+ # args have to match the attributes class attribute
+ def __init__(self, *args, **kwargs):
+ # First, extract optional attributes from kwargs if present
+ # Optional attributes can never be positional arguments
+ for optional_attribute in self.optional_attributes:
+ optional_attribute_value = kwargs.pop(optional_attribute, None)
+ setattr(self, optional_attribute, optional_attribute_value)
+
+ # Check audio tokenizer for its class but do not treat it as attr to avoid saving weights
+ if optional_attribute == "audio_tokenizer" and optional_attribute_value is not None:
+ proper_class = self.check_argument_for_proper_class(optional_attribute, optional_attribute_value)
+
+ if not (is_torch_available() and isinstance(optional_attribute_value, PreTrainedAudioTokenizerBase)):
+ raise ValueError(
+ f"Tried to use `{proper_class}` for audio tokenization. However, this class is not"
+ " registered for audio tokenization."
+ )
+
+ # Sanitize args and kwargs
+ for key in kwargs:
+ if key not in self.attributes:
+ raise TypeError(f"Unexpected keyword argument {key}.")
+ for arg, attribute_name in zip(args, self.attributes):
+ if attribute_name in kwargs:
+ raise TypeError(f"Got multiple values for argument {attribute_name}.")
+ else:
+ kwargs[attribute_name] = arg
+
+ if len(kwargs) != len(self.attributes):
+ raise ValueError(
+ f"This processor requires {len(self.attributes)} arguments: {', '.join(self.attributes)}. Got "
+ f"{len(args)} arguments instead."
+ )
+
+ # Check each arg is of the proper class (this will also catch a user initializing in the wrong order)
+ for attribute_name, arg in kwargs.items():
+ self.check_argument_for_proper_class(attribute_name, arg)
+ setattr(self, attribute_name, arg)
+
+ def __call__(
+ self,
+ images: Optional[ImageInput] = None,
+ text: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]] = None,
+ videos: Optional[VideoInput] = None,
+ audio: Optional[AudioInput] = None,
+ **kwargs: Unpack[ProcessingKwargs],
+ ):
+ """
+ Main method to prepare for model inputs. This method forwards the each modality argument to its own processor
+ along with `kwargs`. Please refer to the docstring of the each processor attributes for more information.
+
+ Args:
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+ tensor. Both channels-first and channels-last formats are supported.
+ text (`TextInput`, `PreTokenizedInput`, `list[TextInput]`, `list[PreTokenizedInput]`, *optional*):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+ videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
+ The video or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
+ tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
+ audio (`np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, `list[torch.Tensor]`):
+ The audio or batch of audio to be prepared. Each audio can be a NumPy array or PyTorch
+ tensor.
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors of a particular framework. Acceptable values are:
+
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return NumPy `np.ndarray` objects.
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
+
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] object with processed inputs in a dict format.
+ """
+ if images is None and text is None and videos is None and audio is None:
+ raise ValueError(f"You need to provide at least one input to call {self.__class__.__name__}")
+
+ kwargs = self._merge_kwargs(
+ self.valid_processor_kwargs,
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs if hasattr(self, "tokenizer") else {},
+ **kwargs,
+ )
+
+ attribute_to_kwargs = {
+ "tokenizer": (text, "text_kwargs"),
+ "image_processor": (images, "images_kwargs"),
+ "video_processor": (videos, "videos_kwargs"),
+ "feature_extractor": (audio, "audio_kwargs"),
+ }
+ outputs = {}
+ for attribute_name in self.attributes:
+ attribute = getattr(self, attribute_name, None)
+ input_data, input_kwargs = attribute_to_kwargs[attribute_name]
+ if input_data is not None and attribute is not None:
+ attribute_output = attribute(input_data, **kwargs[input_kwargs])
+ outputs.update(attribute_output)
+
+ return BatchFeature(outputs)
+
+ def check_argument_for_proper_class(self, argument_name, argument):
+ """
+ Checks the passed argument's class against the expected transformers class. In case of an unexpected
+ mismatch between expected and actual class, an error is raise. Otherwise, the proper retrieved class
+ is returned.
+ """
+ class_name = getattr(self, f"{argument_name}_class")
+ # Nothing is ever going to be an instance of "AutoXxx", in that case we check the base class.
+ class_name = AUTO_TO_BASE_CLASS_MAPPING.get(class_name, class_name)
+ if isinstance(class_name, tuple):
+ proper_class = tuple(self.get_possibly_dynamic_module(n) for n in class_name if n is not None)
+ else:
+ proper_class = self.get_possibly_dynamic_module(class_name)
+
+ if not isinstance(argument, proper_class):
+ raise TypeError(
+ f"Received a {type(argument).__name__} for argument {argument_name}, but a {class_name} was expected."
+ )
+
+ return proper_class
+
+ def to_dict(self, legacy_serialization=True) -> dict[str, Any]:
+ """
+ Serializes this instance to a Python dictionary.
+
+ Returns:
+ `dict[str, Any]`: Dictionary of all the attributes that make up this processor instance.
+ """
+ output = copy.deepcopy(self.__dict__)
+
+ # Get the kwargs in `__init__`.
+ sig = inspect.signature(self.__init__)
+ # Only save the attributes that are presented in the kwargs of `__init__`.
+ attrs_to_save = list(sig.parameters)
+ # extra attributes to be kept
+ attrs_to_save += ["auto_map"]
+
+ if legacy_serialization:
+ # Don't save attributes like `tokenizer`, `image processor` etc. in processor config if `legacy=True`
+ attrs_to_save = [x for x in attrs_to_save if x not in self.__class__.attributes]
+
+ if "tokenizer" in output:
+ del output["tokenizer"]
+ if "qformer_tokenizer" in output:
+ del output["qformer_tokenizer"]
+ if "protein_tokenizer" in output:
+ del output["protein_tokenizer"]
+ if "chat_template" in output:
+ del output["chat_template"]
+
+ def cast_array_to_list(dictionary):
+ """
+ Numpy arrays are not serialiazable but can be in pre-processing dicts.
+ This function casts arrays to list, recusring through the nested configs as well.
+ """
+ for key, value in dictionary.items():
+ if isinstance(value, np.ndarray):
+ dictionary[key] = value.tolist()
+ elif isinstance(value, dict):
+ dictionary[key] = cast_array_to_list(value)
+ return dictionary
+
+ # Serialize attributes as a dict
+ output = {
+ k: v.to_dict() if isinstance(v, PushToHubMixin) else v
+ for k, v in output.items()
+ if (
+ k in attrs_to_save # keep all attributes that have to be serialized
+ and v.__class__.__name__ != "BeamSearchDecoderCTC" # remove attributes with that are objects
+ and (
+ (legacy_serialization and not isinstance(v, PushToHubMixin)) or not legacy_serialization
+ ) # remove `PushToHubMixin` objects
+ )
+ }
+ output = cast_array_to_list(output)
+
+ # Special case, add `audio_tokenizer` dict which points to model weights and path
+ if not legacy_serialization and "audio_tokenizer" in output:
+ audio_tokenizer_dict = {
+ "audio_tokenizer_class": self.audio_tokenizer.__class__.__name__,
+ "audio_tokenizer_name_or_path": self.audio_tokenizer.name_or_path,
+ }
+ # Update or overwrite, what do audio tokenizers expect when loading?
+ output["audio_tokenizer"] = audio_tokenizer_dict
+
+ output["processor_class"] = self.__class__.__name__
+
+ return output
+
+ def to_json_string(self, legacy_serialization=True) -> str:
+ """
+ Serializes this instance to a JSON string.
+
+ Returns:
+ `str`: String containing all the attributes that make up this feature_extractor instance in JSON format.
+ """
+ dictionary = self.to_dict(legacy_serialization=legacy_serialization)
+
+ return json.dumps(dictionary, indent=2, sort_keys=True) + "\n"
+
+ def to_json_file(self, json_file_path: Union[str, os.PathLike], legacy_serialization=True):
+ """
+ Save this instance to a JSON file.
+
+ Args:
+ json_file_path (`str` or `os.PathLike`):
+ Path to the JSON file in which this processor instance's parameters will be saved.
+ """
+ with open(json_file_path, "w", encoding="utf-8") as writer:
+ writer.write(self.to_json_string(legacy_serialization=legacy_serialization))
+
+ def __repr__(self):
+ attributes_repr = [f"- {name}: {repr(getattr(self, name))}" for name in self.attributes]
+ attributes_repr = "\n".join(attributes_repr)
+ return f"{self.__class__.__name__}:\n{attributes_repr}\n\n{self.to_json_string()}"
+
+ def save_pretrained(self, save_directory, push_to_hub: bool = False, legacy_serialization: bool = True, **kwargs):
+ """
+ Saves the attributes of this processor (feature extractor, tokenizer...) in the specified directory so that it
+ can be reloaded using the [`~ProcessorMixin.from_pretrained`] method.
+
+
+
+ This class method is simply calling [`~feature_extraction_utils.FeatureExtractionMixin.save_pretrained`] and
+ [`~tokenization_utils_base.PreTrainedTokenizerBase.save_pretrained`]. Please refer to the docstrings of the
+ methods above for more information.
+
+
+
+ Args:
+ save_directory (`str` or `os.PathLike`):
+ Directory where the feature extractor JSON file and the tokenizer files will be saved (directory will
+ be created if it does not exist).
+ push_to_hub (`bool`, *optional*, defaults to `False`):
+ Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
+ repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
+ namespace).
+ legacy_serialization (`bool`, *optional*, defaults to `True`):
+ Whether or not to save processor attributes in separate config files (legacy) or in processor's config
+ file as a nested dict. Saving all attributes in a single dict will become the default in future versions.
+ Set to `legacy_serialization=True` until then.
+ kwargs (`dict[str, Any]`, *optional*):
+ Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
+ """
+ use_auth_token = kwargs.pop("use_auth_token", None)
+
+ if use_auth_token is not None:
+ warnings.warn(
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+ FutureWarning,
+ )
+ if kwargs.get("token") is not None:
+ raise ValueError(
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+ )
+ kwargs["token"] = use_auth_token
+
+ os.makedirs(save_directory, exist_ok=True)
+
+ if push_to_hub:
+ commit_message = kwargs.pop("commit_message", None)
+ repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
+ repo_id = self._create_repo(repo_id, **kwargs)
+ files_timestamps = self._get_files_timestamps(save_directory)
+ # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
+ # loaded from the Hub.
+ if self._auto_class is not None:
+ attrs = [getattr(self, attribute_name) for attribute_name in self.attributes]
+ configs = [(a.init_kwargs if isinstance(a, PreTrainedTokenizerBase) else a) for a in attrs]
+ configs.append(self)
+ custom_object_save(self, save_directory, config=configs)
+
+ save_jinja_files = kwargs.get("save_jinja_files", True)
+
+ for attribute_name in self.attributes:
+ # Save the tokenizer in its own vocab file. The other attributes are saved as part of `processor_config.json`
+ if attribute_name == "tokenizer":
+ attribute = getattr(self, attribute_name)
+ if hasattr(attribute, "_set_processor_class"):
+ attribute._set_processor_class(self.__class__.__name__)
+
+ # Propagate save_jinja_files to tokenizer to ensure we don't get conflicts
+ attribute.save_pretrained(save_directory, save_jinja_files=save_jinja_files)
+ elif legacy_serialization:
+ attribute = getattr(self, attribute_name)
+ # Include the processor class in attribute config so this processor can then be reloaded with `AutoProcessor` API.
+ if hasattr(attribute, "_set_processor_class"):
+ attribute._set_processor_class(self.__class__.__name__)
+ attribute.save_pretrained(save_directory)
+
+ if self._auto_class is not None:
+ # We added an attribute to the init_kwargs of the tokenizers, which needs to be cleaned up.
+ for attribute_name in self.attributes:
+ attribute = getattr(self, attribute_name)
+ if isinstance(attribute, PreTrainedTokenizerBase):
+ del attribute.init_kwargs["auto_map"]
+
+ # If we save using the predefined names, we can load using `from_pretrained`
+ # plus we save chat_template in its own file
+ output_processor_file = os.path.join(save_directory, PROCESSOR_NAME)
+ output_chat_template_file_jinja = os.path.join(save_directory, CHAT_TEMPLATE_FILE)
+ output_chat_template_file_legacy = os.path.join(
+ save_directory, LEGACY_PROCESSOR_CHAT_TEMPLATE_FILE
+ ) # Legacy filename
+ chat_template_dir = os.path.join(save_directory, CHAT_TEMPLATE_DIR)
+
+ # Save `chat_template` in its own file. We can't get it from `processor_dict` as we popped it in `to_dict`
+ # to avoid serializing chat template in json config file. So let's get it from `self` directly
+ if self.chat_template is not None:
+ save_jinja_files = kwargs.get("save_jinja_files", True)
+ is_single_template = isinstance(self.chat_template, str)
+ if save_jinja_files and is_single_template:
+ # New format for single templates is to save them as chat_template.jinja
+ with open(output_chat_template_file_jinja, "w", encoding="utf-8") as f:
+ f.write(self.chat_template)
+ logger.info(f"chat template saved in {output_chat_template_file_jinja}")
+ elif save_jinja_files and not is_single_template:
+ # New format for multiple templates is to save the default as chat_template.jinja
+ # and the other templates in the chat_templates/ directory
+ for template_name, template in self.chat_template.items():
+ if template_name == "default":
+ with open(output_chat_template_file_jinja, "w", encoding="utf-8") as f:
+ f.write(self.chat_template["default"])
+ logger.info(f"chat template saved in {output_chat_template_file_jinja}")
+ else:
+ os.makedirs(chat_template_dir, exist_ok=True)
+ template_filepath = os.path.join(chat_template_dir, f"{template_name}.jinja")
+ with open(template_filepath, "w", encoding="utf-8") as f:
+ f.write(template)
+ logger.info(f"chat template saved in {template_filepath}")
+ elif is_single_template:
+ # Legacy format for single templates: Put them in chat_template.json
+ chat_template_json_string = (
+ json.dumps({"chat_template": self.chat_template}, indent=2, sort_keys=True) + "\n"
+ )
+ with open(output_chat_template_file_legacy, "w", encoding="utf-8") as writer:
+ writer.write(chat_template_json_string)
+ logger.info(f"chat template saved in {output_chat_template_file_legacy}")
+ elif self.chat_template is not None:
+ # At this point we have multiple templates in the legacy format, which is not supported
+ # chat template dicts are saved to chat_template.json as lists of dicts with fixed key names.
+ raise ValueError(
+ "Multiple chat templates are not supported in the legacy format. Please save them as "
+ "separate files using the `save_jinja_files` argument."
+ )
+
+ if legacy_serialization:
+ output_audio_tokenizer_file = os.path.join(save_directory, AUDIO_TOKENIZER_NAME)
+ processor_dict = self.to_dict()
+
+ # For now, let's not save to `processor_config.json` if the processor doesn't have extra attributes and
+ # `auto_map` is not specified.
+ if set(processor_dict.keys()) != {"processor_class"}:
+ self.to_json_file(output_processor_file)
+ logger.info(f"processor saved in {output_processor_file}")
+
+ if set(processor_dict.keys()) == {"processor_class"}:
+ return_files = []
+ else:
+ return_files = [output_processor_file]
+
+ if self.audio_tokenizer is not None:
+ audio_tokenizer_class = self.audio_tokenizer.__class__.__name__
+ audio_tokenizer_name_or_path = self.audio_tokenizer.name_or_path
+ audio_tokenizer_dict = {
+ "audio_tokenizer_class": audio_tokenizer_class,
+ "audio_tokenizer_name_or_path": audio_tokenizer_name_or_path,
+ }
+ audio_tokenizer_json = json.dumps(audio_tokenizer_dict, indent=2, sort_keys=True) + "\n"
+ with open(output_audio_tokenizer_file, "w", encoding="utf-8") as writer:
+ writer.write(audio_tokenizer_json)
+
+ # Create a unified `preprocessor_config.json` and save all attributes as a composite config, except for tokenizers
+ # NOTE: this will become the default way to save all processor attrbiutes in future versions. Toggled off for now to give
+ # us time for smoother transition
+ else:
+ self.to_json_file(output_processor_file, legacy_serialization=False)
+ logger.info(f"processor saved in {output_processor_file}")
+ return_files = [output_processor_file]
+
+ if push_to_hub:
+ self._upload_modified_files(
+ save_directory,
+ repo_id,
+ files_timestamps,
+ commit_message=commit_message,
+ token=kwargs.get("token"),
+ )
+
+ return return_files
+
+ @classmethod
+ def get_processor_dict(
+ cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
+ ) -> tuple[dict[str, Any], dict[str, Any]]:
+ """
+ From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a
+ processor of type [`~processing_utils.ProcessingMixin`] using `from_args_and_dict`.
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
+ subfolder (`str`, *optional*, defaults to `""`):
+ In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
+ specify the folder name here.
+
+ Returns:
+ `tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the processor object.
+ """
+ # holding a copy for optionally loading the audio tokenizer (if available)
+ audio_tokenizer_kwargs = copy.deepcopy(kwargs)
+
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ resume_download = kwargs.pop("resume_download", None)
+ proxies = kwargs.pop("proxies", None)
+ token = kwargs.pop("token", None)
+ local_files_only = kwargs.pop("local_files_only", False)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", "")
+
+ from_pipeline = kwargs.pop("_from_pipeline", None)
+ from_auto_class = kwargs.pop("_from_auto", False)
+
+ user_agent = {"file_type": "processor", "from_auto_class": from_auto_class}
+ if from_pipeline is not None:
+ user_agent["using_pipeline"] = from_pipeline
+
+ if is_offline_mode() and not local_files_only:
+ logger.info("Offline mode: forcing local_files_only=True")
+ local_files_only = True
+
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+ is_local = os.path.isdir(pretrained_model_name_or_path)
+ if os.path.isdir(pretrained_model_name_or_path):
+ processor_file = os.path.join(pretrained_model_name_or_path, PROCESSOR_NAME)
+
+ additional_chat_template_files = {}
+ resolved_additional_chat_template_files = {}
+ if os.path.isfile(pretrained_model_name_or_path):
+ resolved_processor_file = pretrained_model_name_or_path
+ # can't load chat-template and audio tokenizer when given a file as pretrained_model_name_or_path
+ resolved_chat_template_file = None
+ resolved_raw_chat_template_file = None
+ resolved_audio_tokenizer_file = None
+ is_local = True
+ elif is_remote_url(pretrained_model_name_or_path):
+ processor_file = pretrained_model_name_or_path
+ resolved_processor_file = download_url(pretrained_model_name_or_path)
+ # can't load chat-template and audio tokenizer when given a file url as pretrained_model_name_or_path
+ resolved_chat_template_file = None
+ resolved_raw_chat_template_file = None
+ resolved_audio_tokenizer_file = None
+ else:
+ if is_local:
+ template_dir = Path(pretrained_model_name_or_path, CHAT_TEMPLATE_DIR)
+ if template_dir.is_dir():
+ for template_file in template_dir.glob("*.jinja"):
+ template_name = template_file.stem
+ additional_chat_template_files[template_name] = f"{CHAT_TEMPLATE_DIR}/{template_file.name}"
+ else:
+ try:
+ for template in list_repo_templates(
+ pretrained_model_name_or_path,
+ local_files_only=local_files_only,
+ revision=revision,
+ cache_dir=cache_dir,
+ token=token,
+ ):
+ additional_chat_template_files[template] = f"{CHAT_TEMPLATE_DIR}/{template}.jinja"
+ except EntryNotFoundError:
+ pass # No template dir means no template files
+ processor_file = PROCESSOR_NAME
+
+ try:
+ # Load from local folder or from cache or download from model Hub and cache
+ resolved_processor_file = cached_file(
+ pretrained_model_name_or_path,
+ processor_file,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ token=token,
+ user_agent=user_agent,
+ revision=revision,
+ subfolder=subfolder,
+ _raise_exceptions_for_missing_entries=False,
+ )
+
+ # chat_template.json is a legacy file used by the processor class
+ # a raw chat_template.jinja is preferred in future
+ resolved_chat_template_file = cached_file(
+ pretrained_model_name_or_path,
+ LEGACY_PROCESSOR_CHAT_TEMPLATE_FILE,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ token=token,
+ user_agent=user_agent,
+ revision=revision,
+ subfolder=subfolder,
+ _raise_exceptions_for_missing_entries=False,
+ )
+
+ resolved_raw_chat_template_file = cached_file(
+ pretrained_model_name_or_path,
+ CHAT_TEMPLATE_FILE,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ token=token,
+ user_agent=user_agent,
+ revision=revision,
+ subfolder=subfolder,
+ _raise_exceptions_for_missing_entries=False,
+ )
+
+ resolved_additional_chat_template_files = {
+ template_name: cached_file(
+ pretrained_model_name_or_path,
+ template_file,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ token=token,
+ user_agent=user_agent,
+ revision=revision,
+ subfolder=subfolder,
+ _raise_exceptions_for_missing_entries=False,
+ )
+ for template_name, template_file in additional_chat_template_files.items()
+ }
+
+ resolved_audio_tokenizer_file = cached_file(
+ pretrained_model_name_or_path,
+ AUDIO_TOKENIZER_NAME,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ token=token,
+ user_agent=user_agent,
+ revision=revision,
+ subfolder=subfolder,
+ _raise_exceptions_for_missing_entries=False,
+ )
+ except OSError:
+ # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
+ # the original exception.
+ raise
+ except Exception:
+ # For any other exception, we throw a generic error.
+ raise OSError(
+ f"Can't load processor for '{pretrained_model_name_or_path}'. If you were trying to load"
+ " it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
+ f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
+ f" directory containing a {PROCESSOR_NAME} file"
+ )
+
+ # Add chat template as kwarg before returning because most models don't have processor config
+ if resolved_chat_template_file is not None:
+ # This is the legacy path
+ with open(resolved_chat_template_file, encoding="utf-8") as reader:
+ chat_template_json = json.loads(reader.read())
+ chat_templates = {"default": chat_template_json["chat_template"]}
+ if resolved_additional_chat_template_files:
+ raise ValueError(
+ "Cannot load chat template due to conflicting files - this checkpoint combines "
+ "a legacy chat_template.json file with separate template files, which is not "
+ "supported. To resolve this error, replace the legacy chat_template.json file "
+ "with a modern chat_template.jinja file."
+ )
+ else:
+ chat_templates = {
+ template_name: open(template_file, "r", encoding="utf-8").read()
+ for template_name, template_file in resolved_additional_chat_template_files.items()
+ }
+ if resolved_raw_chat_template_file is not None:
+ with open(resolved_raw_chat_template_file, "r", encoding="utf-8") as reader:
+ chat_templates["default"] = reader.read()
+ if isinstance(chat_templates, dict) and "default" in chat_templates and len(chat_templates) == 1:
+ chat_templates = chat_templates["default"] # Flatten when we just have a single template/file
+
+ if chat_templates:
+ kwargs["chat_template"] = chat_templates
+
+ # Existing processors on the Hub created before #27761 being merged don't have `processor_config.json` (if not
+ # updated afterward), and we need to keep `from_pretrained` work. So here it fallbacks to the empty dict.
+ # (`cached_file` called using `_raise_exceptions_for_missing_entries=False` to avoid exception)
+ # However, for models added in the future, we won't get the expected error if this file is missing.
+ if resolved_processor_file is None:
+ # In any case we need to pass `chat_template` if it is available
+ processor_dict = {}
+ else:
+ try:
+ # Load processor dict
+ with open(resolved_processor_file, encoding="utf-8") as reader:
+ text = reader.read()
+ processor_dict = json.loads(text)
+
+ except json.JSONDecodeError:
+ raise OSError(
+ f"It looks like the config file at '{resolved_processor_file}' is not a valid JSON file."
+ )
+
+ if is_local:
+ logger.info(f"loading configuration file {resolved_processor_file}")
+ else:
+ logger.info(f"loading configuration file {processor_file} from cache at {resolved_processor_file}")
+
+ if "chat_template" in processor_dict and processor_dict["chat_template"] is not None:
+ logger.warning_once(
+ "Chat templates should be in a 'chat_template.jinja' file but found key='chat_template' "
+ "in the processor's config. Make sure to move your template to its own file."
+ )
+
+ if "chat_template" in kwargs:
+ processor_dict["chat_template"] = kwargs.pop("chat_template")
+
+ # Audio tokenizer needs to load the model checkpoint first, because the saved
+ # json file contains only references to the model path and repo id
+ if resolved_audio_tokenizer_file is not None or "audio_tokenizer" in processor_dict:
+ if resolved_audio_tokenizer_file is not None:
+ reader = open(resolved_audio_tokenizer_file, "r", encoding="utf-8")
+ audio_tokenizer_dict = reader.read()
+ audio_tokenizer_dict = json.loads(audio_tokenizer_dict)
+ else:
+ audio_tokenizer_dict = processor_dict["audio_tokenizer"]
+
+ audio_tokenizer_class = cls.get_possibly_dynamic_module(audio_tokenizer_dict["audio_tokenizer_class"])
+ audio_tokenizer_path = audio_tokenizer_dict["audio_tokenizer_name_or_path"]
+ processor_dict["audio_tokenizer"] = audio_tokenizer_class.from_pretrained(
+ audio_tokenizer_path, **audio_tokenizer_kwargs
+ )
+
+ # Pop attributes if saved in a single processor dict, they are loaded in `_get_arguments_from_pretrained`
+ for attribute in cls.attributes:
+ processor_dict.pop(attribute, None)
+
+ return processor_dict, kwargs
+
+ @classmethod
+ def from_args_and_dict(cls, args, processor_dict: dict[str, Any], **kwargs):
+ """
+ Instantiates a type of [`~processing_utils.ProcessingMixin`] from a Python dictionary of parameters.
+
+ Args:
+ processor_dict (`dict[str, Any]`):
+ Dictionary that will be used to instantiate the processor object. Such a dictionary can be
+ retrieved from a pretrained checkpoint by leveraging the
+ [`~processing_utils.ProcessingMixin.to_dict`] method.
+ kwargs (`dict[str, Any]`):
+ Additional parameters from which to initialize the processor object.
+
+ Returns:
+ [`~processing_utils.ProcessingMixin`]: The processor object instantiated from those
+ parameters.
+ """
+ processor_dict = processor_dict.copy()
+ return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
+
+ # We have to pop up some unused (but specific) kwargs and then validate that it doesn't contain unused kwargs
+ # If we don't pop, some specific kwargs will raise a warning
+ if "processor_class" in processor_dict:
+ del processor_dict["processor_class"]
+
+ if "auto_map" in processor_dict:
+ del processor_dict["auto_map"]
+
+ # override processor_dict with given kwargs
+ processor_dict.update(kwargs)
+
+ # check if there is an overlap between args and processor_dict
+ accepted_args_and_kwargs = cls.__init__.__code__.co_varnames[: cls.__init__.__code__.co_argcount][1:]
+
+ # validate both processor_dict and given kwargs
+ unused_kwargs, valid_kwargs = cls.validate_init_kwargs(
+ processor_config=processor_dict, valid_kwargs=accepted_args_and_kwargs
+ )
+
+ # update args that are already in processor_dict to avoid duplicate arguments
+ args_to_update = {
+ i: valid_kwargs.pop(arg)
+ for i, arg in enumerate(accepted_args_and_kwargs)
+ if (arg in valid_kwargs and i < len(args))
+ }
+ args = [args_to_update.get(i, arg) for i, arg in enumerate(args)]
+
+ # instantiate processor with used (and valid) kwargs only
+ processor = cls(*args, **valid_kwargs)
+
+ logger.info(f"Processor {processor}")
+ if return_unused_kwargs:
+ return processor, unused_kwargs
+ else:
+ return processor
+
+ def _merge_kwargs(
+ self,
+ ModelProcessorKwargs: ProcessingKwargs,
+ tokenizer_init_kwargs: Optional[dict] = None,
+ **kwargs,
+ ) -> dict[str, dict]:
+ """
+ Method to merge dictionaries of kwargs cleanly separated by modality within a Processor instance.
+ The order of operations is as follows:
+ 1) kwargs passed as before have highest priority to preserve BC.
+ ```python
+ high_priority_kwargs = {"crop_size" = {"height": 222, "width": 222}, "padding" = "max_length"}
+ processor(..., **high_priority_kwargs)
+ ```
+ 2) kwargs passed as modality-specific kwargs have second priority. This is the recommended API.
+ ```python
+ processor(..., text_kwargs={"padding": "max_length"}, images_kwargs={"crop_size": {"height": 222, "width": 222}}})
+ ```
+ 3) kwargs passed during instantiation of a modality processor have fourth priority.
+ ```python
+ tokenizer = tokenizer_class(..., {"padding": "max_length"})
+ image_processor = image_processor_class(...)
+ processor(tokenizer, image_processor) # will pass max_length unless overridden by kwargs at call
+ ```
+ 4) defaults kwargs specified at processor level have lowest priority.
+ ```python
+ class MyProcessingKwargs(ProcessingKwargs, CommonKwargs, TextKwargs, ImagesKwargs, total=False):
+ _defaults = {
+ "text_kwargs": {
+ "padding": "max_length",
+ "max_length": 64,
+ },
+ }
+ ```
+ Args:
+ ModelProcessorKwargs (`ProcessingKwargs`):
+ Typed dictionary of kwargs specifically required by the model passed.
+ tokenizer_init_kwargs (`Dict`, *optional*):
+ Dictionary of kwargs the tokenizer was instantiated with and need to take precedence over defaults.
+
+ Returns:
+ output_kwargs (`Dict`):
+ Dictionary of per-modality kwargs to be passed to each modality-specific processor.
+
+ """
+ # Initialize dictionaries
+ output_kwargs = {
+ "text_kwargs": {},
+ "images_kwargs": {},
+ "audio_kwargs": {},
+ "videos_kwargs": {},
+ "common_kwargs": {},
+ }
+
+ default_kwargs = {
+ "text_kwargs": {},
+ "images_kwargs": {},
+ "audio_kwargs": {},
+ "videos_kwargs": {},
+ "common_kwargs": {},
+ }
+
+ possible_modality_keywords = {"text", "audio", "videos", "images"}
+ used_keys = set()
+
+ # get defaults from set model processor kwargs if they exist
+ for modality in default_kwargs:
+ default_kwargs[modality] = ModelProcessorKwargs._defaults.get(modality, {}).copy()
+ # update defaults with arguments from tokenizer init
+ for modality_key in ModelProcessorKwargs.__annotations__[modality].__annotations__:
+ # init with tokenizer init kwargs if necessary
+ if tokenizer_init_kwargs is not None and modality_key in tokenizer_init_kwargs:
+ value = (
+ getattr(self.tokenizer, modality_key)
+ if hasattr(self.tokenizer, modality_key)
+ else tokenizer_init_kwargs[modality_key]
+ )
+ default_kwargs[modality][modality_key] = value
+ # now defaults kwargs are updated with the tokenizers defaults.
+ # pass defaults to output dictionary
+ output_kwargs.update(default_kwargs)
+
+ # update modality kwargs with passed kwargs
+ non_modality_kwargs = set(kwargs) - set(output_kwargs)
+ for modality, output_kwarg in output_kwargs.items():
+ for modality_key in ModelProcessorKwargs.__annotations__[modality].__annotations__:
+ # check if we received a structured kwarg dict or not to handle it correctly
+ if modality in kwargs:
+ kwarg_value = kwargs[modality].pop(modality_key, "__empty__")
+ # check if this key was passed as a flat kwarg.
+ if kwarg_value != "__empty__" and modality_key in non_modality_kwargs:
+ raise ValueError(
+ f"Keyword argument {modality_key} was passed two times:\n"
+ f"in a dictionary for {modality} and as a **kwarg."
+ )
+ elif modality_key in kwargs:
+ # we get a modality_key instead of popping it because modality-specific processors
+ # can have overlapping kwargs
+ kwarg_value = kwargs.get(modality_key, "__empty__")
+ else:
+ kwarg_value = "__empty__"
+ if not isinstance(kwarg_value, str) or kwarg_value != "__empty__":
+ output_kwarg[modality_key] = kwarg_value
+ used_keys.add(modality_key)
+
+ # Determine if kwargs is a flat dictionary or contains nested dictionaries
+ if any(key in default_kwargs for key in kwargs):
+ # kwargs is dictionary-based, and some keys match modality names
+ for modality, subdict in kwargs.items():
+ if modality in default_kwargs:
+ for subkey, subvalue in subdict.items():
+ if subkey not in used_keys:
+ output_kwargs[modality][subkey] = subvalue
+ used_keys.add(subkey)
+ else:
+ # kwargs is a flat dictionary
+ for key, kwarg in kwargs.items():
+ if key not in used_keys:
+ if key in ModelProcessorKwargs.__annotations__["common_kwargs"].__annotations__:
+ output_kwargs["common_kwargs"][key] = kwarg
+ elif key not in possible_modality_keywords:
+ logger.warning_once(
+ f"Keyword argument `{key}` is not a valid argument for this processor and will be ignored."
+ )
+
+ # all modality-specific kwargs are updated with common kwargs
+ for kwarg in output_kwargs.values():
+ kwarg.update(output_kwargs["common_kwargs"])
+ return output_kwargs
+
+ @classmethod
+ def from_pretrained(
+ cls: type[SpecificProcessorType],
+ pretrained_model_name_or_path: Union[str, os.PathLike],
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
+ force_download: bool = False,
+ local_files_only: bool = False,
+ token: Optional[Union[str, bool]] = None,
+ revision: str = "main",
+ **kwargs,
+ ) -> SpecificProcessorType:
+ r"""
+ Instantiate a processor associated with a pretrained model.
+
+
+
+ This class method is simply calling the feature extractor
+ [`~feature_extraction_utils.FeatureExtractionMixin.from_pretrained`], image processor
+ [`~image_processing_utils.ImageProcessingMixin`] and the tokenizer
+ [`~tokenization_utils_base.PreTrainedTokenizer.from_pretrained`] methods. Please refer to the docstrings of the
+ methods above for more information.
+
+
+
+ Args:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ This can be either:
+
+ - a string, the *model id* of a pretrained feature_extractor hosted inside a model repo on
+ huggingface.co.
+ - a path to a *directory* containing a feature extractor file saved using the
+ [`~SequenceFeatureExtractor.save_pretrained`] method, e.g., `./my_model_directory/`.
+ - a path or url to a saved feature extractor JSON *file*, e.g.,
+ `./my_model_directory/preprocessor_config.json`.
+ **kwargs
+ Additional keyword arguments passed along to both
+ [`~feature_extraction_utils.FeatureExtractionMixin.from_pretrained`] and
+ [`~tokenization_utils_base.PreTrainedTokenizer.from_pretrained`].
+ """
+ kwargs["cache_dir"] = cache_dir
+ kwargs["force_download"] = force_download
+ kwargs["local_files_only"] = local_files_only
+ kwargs["revision"] = revision
+
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ if use_auth_token is not None:
+ warnings.warn(
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+ FutureWarning,
+ )
+ if token is not None:
+ raise ValueError(
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+ )
+ token = use_auth_token
+
+ if token is not None:
+ kwargs["token"] = token
+
+ args = cls._get_arguments_from_pretrained(pretrained_model_name_or_path, **kwargs)
+ processor_dict, kwargs = cls.get_processor_dict(pretrained_model_name_or_path, **kwargs)
+ return cls.from_args_and_dict(args, processor_dict, **kwargs)
+
+ @classmethod
+ def register_for_auto_class(cls, auto_class="AutoProcessor"):
+ """
+ Register this class with a given auto class. This should only be used for custom feature extractors as the ones
+ in the library are already mapped with `AutoProcessor`.
+
+
+
+ Args:
+ auto_class (`str` or `type`, *optional*, defaults to `"AutoProcessor"`):
+ The auto class to register this new feature extractor with.
+ """
+ if not isinstance(auto_class, str):
+ auto_class = auto_class.__name__
+
+ import transformers.models.auto as auto_module
+
+ if not hasattr(auto_module, auto_class):
+ raise ValueError(f"{auto_class} is not a valid auto class.")
+
+ cls._auto_class = auto_class
+
+ @classmethod
+ def _get_arguments_from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+ """
+ Identify and instantiate the subcomponents of Processor classes, like image processors and
+ tokenizers. This method uses the Processor attributes like `tokenizer_class` to figure out what class those
+ subcomponents should be. Note that any subcomponents must either be library classes that are accessible in
+ the `transformers` root, or they must be custom code that has been registered with the relevant autoclass,
+ via methods like `AutoTokenizer.register()`. If neither of these conditions are fulfilled, this method
+ will be unable to find the relevant subcomponent class and will raise an error.
+ """
+ args = []
+ for attribute_name in cls.attributes:
+ class_name = getattr(cls, f"{attribute_name}_class")
+ if isinstance(class_name, tuple):
+ classes = tuple(cls.get_possibly_dynamic_module(n) if n is not None else None for n in class_name)
+ if attribute_name == "image_processor":
+ # TODO: @yoni, change logic in v4.52 (when use_fast set to True by default)
+ use_fast = kwargs.get("use_fast")
+ if use_fast is None:
+ logger.warning_once(
+ "Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. "
+ "`use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. "
+ "This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`."
+ )
+ else:
+ use_fast = kwargs.get("use_fast", True)
+ if use_fast and classes[1] is not None:
+ attribute_class = classes[1]
+ else:
+ attribute_class = classes[0]
+ else:
+ attribute_class = cls.get_possibly_dynamic_module(class_name)
+
+ args.append(attribute_class.from_pretrained(pretrained_model_name_or_path, **kwargs))
+
+ return args
+
+ @staticmethod
+ def get_possibly_dynamic_module(module_name):
+ if hasattr(transformers_module, module_name):
+ return getattr(transformers_module, module_name)
+ lookup_locations = [
+ transformers_module.IMAGE_PROCESSOR_MAPPING,
+ transformers_module.VIDEO_PROCESSOR_MAPPING,
+ transformers_module.TOKENIZER_MAPPING,
+ transformers_module.FEATURE_EXTRACTOR_MAPPING,
+ transformers_module.MODEL_FOR_AUDIO_TOKENIZATION_MAPPING,
+ ]
+ for lookup_location in lookup_locations:
+ for custom_class in lookup_location._extra_content.values():
+ if isinstance(custom_class, tuple):
+ for custom_subclass in custom_class:
+ if custom_subclass is not None and custom_subclass.__name__ == module_name:
+ return custom_subclass
+ elif custom_class is not None and custom_class.__name__ == module_name:
+ return custom_class
+ raise ValueError(
+ f"Could not find module {module_name} in `transformers`. If this is a custom class, "
+ f"it should be registered using the relevant `AutoClass.register()` function so that "
+ f"other functions can find it!"
+ )
+
+ def batch_decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
+ refer to the docstring of this method for more information.
+ """
+ if not hasattr(self, "tokenizer"):
+ raise ValueError(f"Cannot batch decode text: {self.__class__.__name__} has no tokenizer.")
+ return self.tokenizer.batch_decode(*args, **kwargs)
+
+ def decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
+ the docstring of this method for more information.
+ """
+ if not hasattr(self, "tokenizer"):
+ raise ValueError(f"Cannot decode text: {self.__class__.__name__} has no tokenizer.")
+ return self.tokenizer.decode(*args, **kwargs)
+
+ @property
+ def model_input_names(self):
+ model_input_names = []
+ for attribute_name in self.attributes:
+ attribute = getattr(self, attribute_name, None)
+ attr_input_names = getattr(attribute, "model_input_names")
+ model_input_names.extend(attr_input_names)
+ return model_input_names
+
+ @staticmethod
+ def validate_init_kwargs(processor_config, valid_kwargs):
+ kwargs_from_config = set(processor_config.keys())
+ valid_kwargs_set = set(valid_kwargs)
+
+ unused_keys = kwargs_from_config - valid_kwargs_set
+ valid_keys = kwargs_from_config & valid_kwargs_set
+
+ unused_kwargs = {k: processor_config[k] for k in unused_keys} if unused_keys else {}
+ valid_kwargs = {k: processor_config[k] for k in valid_keys} if valid_keys else {}
+
+ return unused_kwargs, valid_kwargs
+
+ @deprecate_kwarg("video_fps", version="4.58", new_name="fps")
+ @deprecate_kwarg(
+ "video_load_backend",
+ version="4.59",
+ additional_message=". This function will use `torchcodec` by default, or `torchvision` if `torchcodec` is not installed.",
+ )
+ def apply_chat_template(
+ self,
+ conversation: Union[list[dict[str, str]], list[list[dict[str, str]]]],
+ chat_template: Optional[str] = None,
+ **kwargs: Unpack[AllKwargsForChatTemplate],
+ ) -> str:
+ """
+ Similar to the `apply_chat_template` method on tokenizers, this method applies a Jinja template to input
+ conversations to turn them into a single tokenizable string.
+
+ The input is expected to be in the following format, where each message content is a list consisting of text and
+ optionally image or video inputs. One can also provide an image, video, URL or local path which will be used to form
+ `pixel_values` when `return_dict=True`. If not provided, one will get only the formatted text, optionally tokenized text.
+
+ conversation = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
+ {"type": "text", "text": "Please describe this image in detail."},
+ ],
+ },
+ ]
+
+ Args:
+ conversation (`Union[list[Dict, [str, str]], list[list[dict[str, str]]]]`):
+ The conversation to format.
+ chat_template (`Optional[str]`, *optional*):
+ The Jinja template to use for formatting the conversation. If not provided, the tokenizer's
+ chat template is used.
+ """
+ if chat_template is None:
+ if isinstance(self.chat_template, dict) and "default" in self.chat_template:
+ chat_template = self.chat_template["default"]
+ elif isinstance(self.chat_template, dict):
+ raise ValueError(
+ 'The processor has multiple chat templates but none of them are named "default". You need to specify'
+ " which one to use by passing the `chat_template` argument. Available templates are: "
+ f"{', '.join(self.chat_template.keys())}"
+ )
+ elif self.chat_template is not None:
+ chat_template = self.chat_template
+ else:
+ raise ValueError(
+ "Cannot use apply_chat_template because this processor does not have a chat template."
+ )
+ else:
+ if isinstance(self.chat_template, dict) and chat_template in self.chat_template:
+ # It's the name of a template, not a full template string
+ chat_template = self.chat_template[chat_template]
+ else:
+ # It's a template string, render it directly
+ pass
+
+ is_tokenizers_fast = hasattr(self, "tokenizer") and self.tokenizer.__class__.__name__.endswith("Fast")
+
+ if kwargs.get("continue_final_message", False):
+ if kwargs.get("add_generation_prompt", False):
+ raise ValueError(
+ "continue_final_message and add_generation_prompt are not compatible. Use continue_final_message when you want the model to continue the final message, and add_generation_prompt when you want to add a header that will prompt it to start a new assistant message instead."
+ )
+ if kwargs.get("return_assistant_tokens_mask", False):
+ raise ValueError("continue_final_message is not compatible with return_assistant_tokens_mask.")
+
+ if kwargs.get("return_assistant_tokens_mask", False):
+ if not is_tokenizers_fast:
+ raise ValueError(
+ "`return_assistant_tokens_mask` is not possible with slow tokenizers. Make sure you have `tokenizers` installed. "
+ "If the error persists, open an issue to support a Fast tokenizer for your model."
+ )
+ else:
+ kwargs["return_offsets_mapping"] = True # force offset mapping so we can infer token boundaries
+
+ # Fill sets of kwargs that should be used by different parts of template
+ processed_kwargs = {
+ "mm_load_kwargs": {},
+ "template_kwargs": {},
+ }
+
+ for kwarg_type in processed_kwargs:
+ for key in AllKwargsForChatTemplate.__annotations__[kwarg_type].__annotations__:
+ kwarg_type_defaults = AllKwargsForChatTemplate.__annotations__[kwarg_type]
+ default_value = getattr(kwarg_type_defaults, key, None)
+ value = kwargs.pop(key, default_value)
+ if value is not None and not isinstance(value, dict):
+ processed_kwargs[kwarg_type][key] = value
+
+ # pop unused and deprecated kwarg
+ kwargs.pop("video_load_backend", None)
+
+ # Pass unprocessed custom kwargs
+ processed_kwargs["template_kwargs"].update(kwargs)
+
+ if isinstance(conversation, (list, tuple)) and (
+ isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "content")
+ ):
+ is_batched = True
+ conversations = conversation
+ else:
+ is_batched = False
+ conversations = [conversation]
+
+ tokenize = processed_kwargs["template_kwargs"].pop("tokenize", False)
+ return_dict = processed_kwargs["template_kwargs"].pop("return_dict", False)
+ mm_load_kwargs = processed_kwargs["mm_load_kwargs"]
+
+ if tokenize:
+ batch_images, batch_videos = [], []
+ batch_audios = []
+ for conversation in conversations:
+ images, videos = [], []
+ for message in conversation:
+ visuals = [content for content in message["content"] if content["type"] in ["image", "video"]]
+ audio_fnames = [
+ content[key]
+ for content in message["content"]
+ for key in ["audio", "url", "path"]
+ if key in content and content["type"] == "audio"
+ ]
+ image_fnames = [
+ vision_info[key]
+ for vision_info in visuals
+ for key in ["image", "url", "path", "base64"]
+ if key in vision_info and vision_info["type"] == "image"
+ ]
+ images.extend(image_fnames)
+ video_fnames = [
+ vision_info[key]
+ for vision_info in visuals
+ for key in ["video", "url", "path"]
+ if key in vision_info and vision_info["type"] == "video"
+ ]
+ videos.extend(video_fnames)
+
+ # Audio models do not accept nested list of audios (yet!) so we construct a flat input audio list
+ if not mm_load_kwargs["load_audio_from_video"]:
+ for fname in audio_fnames:
+ batch_audios.append(load_audio(fname, sampling_rate=mm_load_kwargs["sampling_rate"]))
+ else:
+ for fname in video_fnames:
+ batch_audios.append(load_audio(fname, sampling_rate=mm_load_kwargs["sampling_rate"]))
+
+ # Currently all processors can accept nested list of batches, but not flat list of visuals
+ # So we'll make a batched list of images and let the processor handle it
+ batch_images.append(images)
+ batch_videos.append(videos)
+
+ prompt, generation_indices = render_jinja_template(
+ conversations=conversations,
+ chat_template=chat_template,
+ **processed_kwargs["template_kwargs"], # different flags such as `return_assistant_mask`
+ **self.tokenizer.special_tokens_map, # tokenizer special tokens are used by some templates
+ )
+
+ if not is_batched:
+ prompt = prompt[0]
+
+ if tokenize:
+ # Tokenizer's `apply_chat_template` never adds special tokens when tokenizing
+ # But processor's `apply_chat_template` didn't have an option to tokenize, so users had to format the prompt
+ # and pass it to the processor. Users thus never worried about special tokens relying on processor handling
+ # everything internally. The below line is to keep BC for that and be able to work with model that have
+ # special tokens in the template (consistent with tokenizers). We dont want to raise warning, it will flood command line
+ # without actionable solution for users
+ single_prompt = prompt[0] if is_batched else prompt
+ if self.tokenizer.bos_token is not None and single_prompt.startswith(self.tokenizer.bos_token):
+ kwargs["add_special_tokens"] = False
+
+ # Always sample frames by default unless explicitly set to `False` by users. If users do not pass `num_frames`/`fps`
+ # sampling should not done for BC.
+ if "do_sample_frames" not in kwargs and (
+ kwargs.get("fps") is not None or kwargs.get("num_frames") is not None
+ ):
+ kwargs["do_sample_frames"] = True
+
+ images_exist = any((im is not None) for im_list in batch_images for im in im_list)
+ videos_exist = any((vid is not None) for vid_list in batch_videos for vid in vid_list)
+ out = self(
+ text=prompt,
+ images=batch_images if images_exist else None,
+ videos=batch_videos if videos_exist else None,
+ audio=batch_audios if batch_audios else None,
+ **kwargs,
+ )
+
+ if return_dict:
+ if processed_kwargs["template_kwargs"].get("return_assistant_tokens_mask", False):
+ assistant_masks = []
+ offset_mapping = out.pop("offset_mapping")
+ input_ids = out["input_ids"]
+ for i in range(len(input_ids)):
+ current_mask = [0] * len(input_ids[i])
+ offsets = offset_mapping[i]
+ offset_starts = [start for start, end in offsets]
+ for assistant_start_char, assistant_end_char in generation_indices[i]:
+ start_pos = bisect.bisect_left(offset_starts, assistant_start_char)
+ end_pos = bisect.bisect_left(offset_starts, assistant_end_char)
+
+ if not (
+ start_pos >= 0
+ and offsets[start_pos][0] <= assistant_start_char < offsets[start_pos][1]
+ ):
+ # start_token is out of bounds maybe due to truncation.
+ continue
+ for token_id in range(start_pos, end_pos if end_pos else len(input_ids[i])):
+ current_mask[token_id] = 1
+ assistant_masks.append(current_mask)
+ out["assistant_masks"] = assistant_masks
+ out.convert_to_tensors(tensor_type=kwargs.get("return_tensors"))
+ return out
+ else:
+ return out["input_ids"]
+ return prompt
+
+ def post_process_image_text_to_text(self, generated_outputs, skip_special_tokens=True, **kwargs):
+ """
+ Post-process the output of a vlm to decode the text.
+
+ Args:
+ generated_outputs (`torch.Tensor` or `np.ndarray`):
+ The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
+ or `(sequence_length,)`.
+ skip_special_tokens (`bool`, *optional*, defaults to `True`):
+ Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
+ **kwargs:
+ Additional arguments to be passed to the tokenizer's `batch_decode method`.
+
+ Returns:
+ `list[str]`: The decoded text.
+ """
+ return self.tokenizer.batch_decode(generated_outputs, skip_special_tokens=skip_special_tokens, **kwargs)
+
+ def _check_special_mm_tokens(self, text: list[str], text_inputs: "BatchFeature", modalities: list[str]):
+ """
+ Checks that number of special tokens in text and processed text is same. The count can be different
+ if tokenized text was truncated, leading to issues in model code.
+ """
+ for modality in modalities:
+ token_str = getattr(self, f"{modality}_token")
+ token_id = getattr(self, f"{modality}_token_id")
+ ids_count = [list(ids).count(token_id) for ids in text_inputs["input_ids"]]
+ text_count = [sample.count(token_str) for sample in text]
+
+ if ids_count != text_count:
+ raise ValueError(
+ f"Mismatch in `{modality}` token count between text and `input_ids`. Got ids={ids_count} and text={text_count}. "
+ "Likely due to `truncation='max_length'`. Please disable truncation or increase `max_length`."
+ )
+
+
+ProcessorMixin.push_to_hub = copy_func(ProcessorMixin.push_to_hub)
+if ProcessorMixin.push_to_hub.__doc__ is not None:
+ ProcessorMixin.push_to_hub.__doc__ = ProcessorMixin.push_to_hub.__doc__.format(
+ object="processor", object_class="AutoProcessor", object_files="processor files"
+ )
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/py.typed b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/py.typed
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/pytorch_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/pytorch_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1f41117d4cfeb5242a83197ea4e2b04b2d8e9a7
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/pytorch_utils.py
@@ -0,0 +1,380 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+import inspect
+from functools import lru_cache, wraps
+from typing import Callable
+
+import torch
+from safetensors.torch import storage_ptr, storage_size
+from torch import nn
+
+from .utils import (
+ is_torch_greater_or_equal,
+ is_torch_xla_available,
+ is_torch_xpu_available,
+ is_torchdynamo_compiling,
+ logging,
+)
+
+
+ALL_LAYERNORM_LAYERS = [nn.LayerNorm]
+
+logger = logging.get_logger(__name__)
+
+is_torch_greater_or_equal_than_2_8 = is_torch_greater_or_equal("2.8", accept_dev=True)
+is_torch_greater_or_equal_than_2_6 = is_torch_greater_or_equal("2.6", accept_dev=True)
+is_torch_greater_or_equal_than_2_4 = is_torch_greater_or_equal("2.4", accept_dev=True)
+is_torch_greater_or_equal_than_2_3 = is_torch_greater_or_equal("2.3", accept_dev=True)
+
+# For backwards compatibility (e.g. some remote codes on Hub using those variables).
+is_torch_greater_or_equal_than_2_2 = is_torch_greater_or_equal("2.2", accept_dev=True)
+is_torch_greater_or_equal_than_2_1 = is_torch_greater_or_equal("2.1", accept_dev=True)
+is_torch_greater_or_equal_than_2_0 = is_torch_greater_or_equal("2.0", accept_dev=True)
+is_torch_greater_or_equal_than_1_13 = is_torch_greater_or_equal("1.13", accept_dev=True)
+is_torch_greater_or_equal_than_1_12 = is_torch_greater_or_equal("1.12", accept_dev=True)
+
+# Cache this result has it's a C FFI call which can be pretty time-consuming
+_torch_distributed_available = torch.distributed.is_available()
+
+
+def softmax_backward_data(parent, grad_output, output):
+ """
+ A function that calls the internal `_softmax_backward_data` PyTorch method and that adjusts the arguments according
+ to the torch version detected.
+ """
+
+ from torch import _softmax_backward_data
+
+ return _softmax_backward_data(grad_output, output, parent.dim, output.dtype)
+
+
+def prune_linear_layer(layer: nn.Linear, index: torch.LongTensor, dim: int = 0) -> nn.Linear:
+ """
+ Prune a linear layer to keep only entries in index.
+
+ Used to remove heads.
+
+ Args:
+ layer (`torch.nn.Linear`): The layer to prune.
+ index (`torch.LongTensor`): The indices to keep in the layer.
+ dim (`int`, *optional*, defaults to 0): The dimension on which to keep the indices.
+
+ Returns:
+ `torch.nn.Linear`: The pruned layer as a new layer with `requires_grad=True`.
+ """
+ index = index.to(layer.weight.device)
+ W = layer.weight.index_select(dim, index).detach().clone()
+ if layer.bias is not None:
+ if dim == 1:
+ b = layer.bias.detach().clone()
+ else:
+ b = layer.bias[index].detach().clone()
+ new_size = list(layer.weight.size())
+ new_size[dim] = len(index)
+ new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device)
+ new_layer.weight.requires_grad = False
+ new_layer.weight.copy_(W.contiguous())
+ new_layer.weight.requires_grad = True
+ if layer.bias is not None:
+ new_layer.bias.requires_grad = False
+ new_layer.bias.copy_(b.contiguous())
+ new_layer.bias.requires_grad = True
+ return new_layer
+
+
+class Conv1D(nn.Module):
+ """
+ 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
+
+ Basically works like a linear layer but the weights are transposed.
+
+ Args:
+ nf (`int`): The number of output features.
+ nx (`int`): The number of input features.
+ """
+
+ def __init__(self, nf, nx):
+ super().__init__()
+ self.nf = nf
+ self.nx = nx
+ self.weight = nn.Parameter(torch.empty(nx, nf))
+ self.bias = nn.Parameter(torch.zeros(nf))
+ nn.init.normal_(self.weight, std=0.02)
+
+ def __repr__(self) -> str:
+ return "Conv1D(nf={nf}, nx={nx})".format(**self.__dict__)
+
+ def forward(self, x):
+ size_out = x.size()[:-1] + (self.nf,)
+ x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
+ x = x.view(size_out)
+ return x
+
+
+def prune_conv1d_layer(layer: Conv1D, index: torch.LongTensor, dim: int = 1) -> Conv1D:
+ """
+ Prune a Conv1D layer to keep only entries in index. A Conv1D work as a Linear layer (see e.g. BERT) but the weights
+ are transposed.
+
+ Used to remove heads.
+
+ Args:
+ layer ([`~pytorch_utils.Conv1D`]): The layer to prune.
+ index (`torch.LongTensor`): The indices to keep in the layer.
+ dim (`int`, *optional*, defaults to 1): The dimension on which to keep the indices.
+
+ Returns:
+ [`~pytorch_utils.Conv1D`]: The pruned layer as a new layer with `requires_grad=True`.
+ """
+ index = index.to(layer.weight.device)
+ W = layer.weight.index_select(dim, index).detach().clone()
+ if dim == 0:
+ b = layer.bias.detach().clone()
+ else:
+ b = layer.bias[index].detach().clone()
+ new_size = list(layer.weight.size())
+ new_size[dim] = len(index)
+ new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device)
+ new_layer.weight.requires_grad = False
+ new_layer.weight.copy_(W.contiguous())
+ new_layer.weight.requires_grad = True
+ new_layer.bias.requires_grad = False
+ new_layer.bias.copy_(b.contiguous())
+ new_layer.bias.requires_grad = True
+ return new_layer
+
+
+def prune_layer(layer: nn.Linear | Conv1D, index: torch.LongTensor, dim: int | None = None) -> nn.Linear | Conv1D:
+ """
+ Prune a Conv1D or linear layer to keep only entries in index.
+
+ Used to remove heads.
+
+ Args:
+ layer (`Union[torch.nn.Linear, Conv1D]`): The layer to prune.
+ index (`torch.LongTensor`): The indices to keep in the layer.
+ dim (`int`, *optional*): The dimension on which to keep the indices.
+
+ Returns:
+ `torch.nn.Linear` or [`~pytorch_utils.Conv1D`]: The pruned layer as a new layer with `requires_grad=True`.
+ """
+ if isinstance(layer, nn.Linear):
+ return prune_linear_layer(layer, index, dim=0 if dim is None else dim)
+ elif isinstance(layer, Conv1D):
+ return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim)
+ else:
+ raise ValueError(f"Can't prune layer of class {layer.__class__}")
+
+
+def apply_chunking_to_forward(
+ forward_fn: Callable[..., torch.Tensor],
+ chunk_size: int,
+ chunk_dim: int,
+ *input_tensors,
+) -> torch.Tensor:
+ """
+ This function chunks the `input_tensors` into smaller input tensor parts of size `chunk_size` over the dimension
+ `chunk_dim`. It then applies a layer `forward_fn` to each chunk independently to save memory.
+
+ If the `forward_fn` is independent across the `chunk_dim` this function will yield the same result as directly
+ applying `forward_fn` to `input_tensors`.
+
+ Args:
+ forward_fn (`Callable[..., torch.Tensor]`):
+ The forward function of the model.
+ chunk_size (`int`):
+ The chunk size of a chunked tensor: `num_chunks = len(input_tensors[0]) / chunk_size`.
+ chunk_dim (`int`):
+ The dimension over which the `input_tensors` should be chunked.
+ input_tensors (`tuple[torch.Tensor]`):
+ The input tensors of `forward_fn` which will be chunked
+
+ Returns:
+ `torch.Tensor`: A tensor with the same shape as the `forward_fn` would have given if applied`.
+
+
+ Examples:
+
+ ```python
+ # rename the usual forward() fn to forward_chunk()
+ def forward_chunk(self, hidden_states):
+ hidden_states = self.decoder(hidden_states)
+ return hidden_states
+
+
+ # implement a chunked forward function
+ def forward(self, hidden_states):
+ return apply_chunking_to_forward(self.forward_chunk, self.chunk_size_lm_head, self.seq_len_dim, hidden_states)
+ ```"""
+
+ assert len(input_tensors) > 0, f"{input_tensors} has to be a tuple/list of tensors"
+
+ # inspect.signature exist since python 3.5 and is a python method -> no problem with backward compatibility
+ num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters)
+ if num_args_in_forward_chunk_fn != len(input_tensors):
+ raise ValueError(
+ f"forward_chunk_fn expects {num_args_in_forward_chunk_fn} arguments, but only {len(input_tensors)} input "
+ "tensors are given"
+ )
+
+ if chunk_size > 0:
+ tensor_shape = input_tensors[0].shape[chunk_dim]
+ for input_tensor in input_tensors:
+ if input_tensor.shape[chunk_dim] != tensor_shape:
+ raise ValueError(
+ f"All input tenors have to be of the same shape: {tensor_shape}, "
+ f"found shape {input_tensor.shape[chunk_dim]}"
+ )
+
+ if input_tensors[0].shape[chunk_dim] % chunk_size != 0:
+ raise ValueError(
+ f"The dimension to be chunked {input_tensors[0].shape[chunk_dim]} has to be a multiple of the chunk "
+ f"size {chunk_size}"
+ )
+
+ num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size
+
+ # chunk input tensor into tuples
+ input_tensors_chunks = tuple(input_tensor.chunk(num_chunks, dim=chunk_dim) for input_tensor in input_tensors)
+ # apply forward fn to every tuple
+ output_chunks = tuple(forward_fn(*input_tensors_chunk) for input_tensors_chunk in zip(*input_tensors_chunks))
+ # concatenate output at same dimension
+ return torch.cat(output_chunks, dim=chunk_dim)
+
+ return forward_fn(*input_tensors)
+
+
+def find_pruneable_heads_and_indices(
+ heads: list[int], n_heads: int, head_size: int, already_pruned_heads: set[int]
+) -> tuple[set[int], torch.LongTensor]:
+ """
+ Finds the heads and their indices taking `already_pruned_heads` into account.
+
+ Args:
+ heads (`list[int]`): List of the indices of heads to prune.
+ n_heads (`int`): The number of heads in the model.
+ head_size (`int`): The size of each head.
+ already_pruned_heads (`Set[int]`): A set of already pruned heads.
+
+ Returns:
+ `tuple[Set[int], torch.LongTensor]`: A tuple with the indices of heads to prune taking `already_pruned_heads`
+ into account and the indices of rows/columns to keep in the layer weight.
+ """
+ mask = torch.ones(n_heads, head_size)
+ heads = set(heads) - already_pruned_heads # Convert to set and remove already pruned heads
+ for head in heads:
+ # Compute how many pruned heads are before the head and move the index accordingly
+ head = head - sum(1 if h < head else 0 for h in already_pruned_heads)
+ mask[head] = 0
+ mask = mask.view(-1).contiguous().eq(1)
+ index: torch.LongTensor = torch.arange(len(mask))[mask].long()
+ return heads, index
+
+
+def meshgrid(*tensors: torch.Tensor | list[torch.Tensor], indexing: str | None = None) -> tuple[torch.Tensor, ...]:
+ """
+ Wrapper around torch.meshgrid to avoid warning messages about the introduced `indexing` argument.
+
+ Reference: https://pytorch.org/docs/1.13/generated/torch.meshgrid.html
+ """
+ return torch.meshgrid(*tensors, indexing=indexing)
+
+
+def id_tensor_storage(tensor: torch.Tensor) -> tuple[torch.device, int, int]:
+ """
+ Unique identifier to a tensor storage. Multiple different tensors can share the same underlying storage. For
+ example, "meta" tensors all share the same storage, and thus their identifier will all be equal. This identifier is
+ guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with
+ non-overlapping lifetimes may have the same id.
+ """
+ if _torch_distributed_available and is_torch_greater_or_equal("2.5"):
+ from torch.distributed.tensor import DTensor
+
+ if isinstance(tensor, DTensor):
+ local_tensor = tensor.to_local()
+ return tensor.device, local_tensor.storage().data_ptr(), tensor.nbytes
+
+ if tensor.device.type == "xla" and is_torch_xla_available():
+ # NOTE: xla tensors dont have storage
+ # use some other unique id to distinguish.
+ # this is a XLA tensor, it must be created using torch_xla's
+ # device. So the following import is safe:
+ import torch_xla
+
+ unique_id = torch_xla._XLAC._xla_get_tensor_id(tensor)
+ else:
+ unique_id = storage_ptr(tensor)
+
+ return tensor.device, unique_id, storage_size(tensor)
+
+
+def isin_mps_friendly(elements: torch.Tensor, test_elements: torch.Tensor | int) -> torch.Tensor:
+ """
+ Same as `torch.isin` without flags, but MPS-friendly. We can remove this function when we stop supporting
+ torch <= 2.3. See https://github.com/pytorch/pytorch/issues/77764#issuecomment-2067838075
+
+ Args:
+ elements (`torch.Tensor`): Input elements
+ test_elements (`torch.Tensor` or `int`): The elements to check against.
+
+ Returns:
+ `torch.Tensor`: A boolean tensor of the same shape as `elements` that is True for `elements` in `test_elements`
+ and False otherwise
+ """
+
+ if elements.device.type == "mps" and not is_torch_greater_or_equal_than_2_4:
+ test_elements = torch.tensor(test_elements)
+ if test_elements.ndim == 0:
+ test_elements = test_elements.unsqueeze(0)
+ return elements.tile(test_elements.shape[0], 1).eq(test_elements.unsqueeze(1)).sum(dim=0).bool().squeeze()
+ else:
+ # Note: don't use named arguments in `torch.isin`, see https://github.com/pytorch/pytorch/issues/126045
+ return torch.isin(elements, test_elements)
+
+
+@wraps(lru_cache)
+def compile_compatible_method_lru_cache(*lru_args, **lru_kwargs):
+ """
+ LRU cache decorator from standard functools library, but with a workaround to disable
+ caching when torchdynamo is compiling. Expected to work with class methods.
+ """
+
+ def decorator(func):
+ func_with_cache = lru_cache(*lru_args, **lru_kwargs)(func)
+
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ if is_torchdynamo_compiling():
+ return func(*args, **kwargs)
+ else:
+ return func_with_cache(*args, **kwargs)
+
+ return wrapper
+
+ return decorator
+
+
+def infer_device():
+ """
+ Infers available device.
+ """
+ torch_device = "cpu"
+ if torch.cuda.is_available():
+ torch_device = "cuda"
+ elif is_torch_xpu_available():
+ torch_device = "xpu"
+
+ return torch_device
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/safetensors_conversion.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/safetensors_conversion.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1612d3ea57c98fd1d383887cfbeb4e2882d3963
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/safetensors_conversion.py
@@ -0,0 +1,105 @@
+from typing import Optional
+
+import requests
+from huggingface_hub import Discussion, HfApi, get_repo_discussions
+
+from .utils import cached_file, http_user_agent, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+def previous_pr(api: HfApi, model_id: str, pr_title: str, token: str) -> Optional["Discussion"]:
+ main_commit = api.list_repo_commits(model_id, token=token)[0].commit_id
+ for discussion in get_repo_discussions(repo_id=model_id, token=token):
+ if discussion.title == pr_title and discussion.status == "open" and discussion.is_pull_request:
+ commits = api.list_repo_commits(model_id, revision=discussion.git_reference, token=token)
+
+ if main_commit == commits[1].commit_id:
+ return discussion
+ return None
+
+
+def spawn_conversion(token: str, private: bool, model_id: str):
+ logger.info("Attempting to convert .bin model on the fly to safetensors.")
+
+ safetensors_convert_space_url = "https://safetensors-convert.hf.space"
+ sse_url = f"{safetensors_convert_space_url}/call/run"
+
+ def start(_sse_connection):
+ for line in _sse_connection.iter_lines():
+ line = line.decode()
+ if line.startswith("event:"):
+ status = line[7:]
+ logger.debug(f"Safetensors conversion status: {status}")
+
+ if status == "complete":
+ return
+ elif status == "heartbeat":
+ logger.debug("Heartbeat")
+ else:
+ logger.debug(f"Unknown status {status}")
+ else:
+ logger.debug(line)
+
+ data = {"data": [model_id, private, token]}
+
+ result = requests.post(sse_url, stream=True, json=data).json()
+ event_id = result["event_id"]
+
+ with requests.get(f"{sse_url}/{event_id}", stream=True) as sse_connection:
+ try:
+ logger.debug("Spawning safetensors automatic conversion.")
+ start(sse_connection)
+ except Exception as e:
+ logger.warning(f"Error during conversion: {repr(e)}")
+
+
+def get_conversion_pr_reference(api: HfApi, model_id: str, **kwargs):
+ private = api.model_info(model_id).private
+
+ logger.info("Attempting to create safetensors variant")
+ pr_title = "Adding `safetensors` variant of this model"
+ token = kwargs.get("token")
+
+ # This looks into the current repo's open PRs to see if a PR for safetensors was already open. If so, it
+ # returns it. It checks that the PR was opened by the bot and not by another user so as to prevent
+ # security breaches.
+ pr = previous_pr(api, model_id, pr_title, token=token)
+
+ if pr is None or (not private and pr.author != "SFconvertbot"):
+ spawn_conversion(token, private, model_id)
+ pr = previous_pr(api, model_id, pr_title, token=token)
+ else:
+ logger.info("Safetensors PR exists")
+
+ sha = f"refs/pr/{pr.num}"
+
+ return sha
+
+
+def auto_conversion(pretrained_model_name_or_path: str, ignore_errors_during_conversion=False, **cached_file_kwargs):
+ try:
+ api = HfApi(token=cached_file_kwargs.get("token"), headers={"user-agent": http_user_agent()})
+ sha = get_conversion_pr_reference(api, pretrained_model_name_or_path, **cached_file_kwargs)
+
+ if sha is None:
+ return None, None
+ cached_file_kwargs["revision"] = sha
+ del cached_file_kwargs["_commit_hash"]
+
+ # This is an additional HEAD call that could be removed if we could infer sharded/non-sharded from the PR
+ # description.
+ sharded = api.file_exists(
+ pretrained_model_name_or_path,
+ "model.safetensors.index.json",
+ revision=sha,
+ token=cached_file_kwargs.get("token"),
+ )
+ filename = "model.safetensors.index.json" if sharded else "model.safetensors"
+
+ resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
+ return resolved_archive_file, sha, sharded
+ except Exception as e:
+ if not ignore_errors_during_conversion:
+ raise e
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/testing_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/testing_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..499716789b9bc5197b2669bb4099c13741f4fd79
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/testing_utils.py
@@ -0,0 +1,4154 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import ast
+import collections
+import contextlib
+import copy
+import doctest
+import functools
+import gc
+import importlib
+import inspect
+import logging
+import multiprocessing
+import os
+import re
+import shlex
+import shutil
+import subprocess
+import sys
+import tempfile
+import threading
+import time
+import traceback
+import types
+import unittest
+from collections import UserDict, defaultdict
+from collections.abc import Generator, Iterable, Iterator, Mapping
+from dataclasses import MISSING, fields
+from functools import cache, wraps
+from io import StringIO
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+from unittest import mock
+from unittest.mock import patch
+
+import huggingface_hub.utils
+import requests
+import urllib3
+from huggingface_hub import delete_repo
+from packaging import version
+
+from transformers import Trainer
+from transformers import logging as transformers_logging
+
+from .integrations import (
+ is_clearml_available,
+ is_optuna_available,
+ is_ray_available,
+ is_sigopt_available,
+ is_swanlab_available,
+ is_tensorboard_available,
+ is_trackio_available,
+ is_wandb_available,
+)
+from .integrations.deepspeed import is_deepspeed_available
+from .utils import (
+ ACCELERATE_MIN_VERSION,
+ GGUF_MIN_VERSION,
+ TRITON_MIN_VERSION,
+ is_accelerate_available,
+ is_apex_available,
+ is_apollo_torch_available,
+ is_aqlm_available,
+ is_auto_awq_available,
+ is_auto_gptq_available,
+ is_auto_round_available,
+ is_av_available,
+ is_bitsandbytes_available,
+ is_bitsandbytes_multi_backend_available,
+ is_bs4_available,
+ is_compressed_tensors_available,
+ is_cv2_available,
+ is_cython_available,
+ is_decord_available,
+ is_detectron2_available,
+ is_eetq_available,
+ is_essentia_available,
+ is_faiss_available,
+ is_fbgemm_gpu_available,
+ is_flash_attn_2_available,
+ is_flash_attn_3_available,
+ is_flax_available,
+ is_flute_available,
+ is_fp_quant_available,
+ is_fsdp_available,
+ is_ftfy_available,
+ is_g2p_en_available,
+ is_galore_torch_available,
+ is_gguf_available,
+ is_gptqmodel_available,
+ is_grokadamw_available,
+ is_hadamard_available,
+ is_hqq_available,
+ is_huggingface_hub_greater_or_equal,
+ is_ipex_available,
+ is_jinja_available,
+ is_jumanpp_available,
+ is_keras_nlp_available,
+ is_kernels_available,
+ is_levenshtein_available,
+ is_librosa_available,
+ is_liger_kernel_available,
+ is_lomo_available,
+ is_mistral_common_available,
+ is_natten_available,
+ is_nltk_available,
+ is_onnx_available,
+ is_openai_available,
+ is_optimum_available,
+ is_optimum_quanto_available,
+ is_pandas_available,
+ is_peft_available,
+ is_phonemizer_available,
+ is_pretty_midi_available,
+ is_psutil_available,
+ is_pyctcdecode_available,
+ is_pytesseract_available,
+ is_pytest_available,
+ is_pytorch_quantization_available,
+ is_quark_available,
+ is_qutlass_available,
+ is_rjieba_available,
+ is_sacremoses_available,
+ is_safetensors_available,
+ is_schedulefree_available,
+ is_scipy_available,
+ is_sentencepiece_available,
+ is_seqio_available,
+ is_spacy_available,
+ is_speech_available,
+ is_spqr_available,
+ is_sudachi_available,
+ is_sudachi_projection_available,
+ is_tf_available,
+ is_tiktoken_available,
+ is_timm_available,
+ is_tokenizers_available,
+ is_torch_available,
+ is_torch_bf16_available_on_device,
+ is_torch_bf16_gpu_available,
+ is_torch_fp16_available_on_device,
+ is_torch_greater_or_equal,
+ is_torch_hpu_available,
+ is_torch_mlu_available,
+ is_torch_neuroncore_available,
+ is_torch_npu_available,
+ is_torch_optimi_available,
+ is_torch_tensorrt_fx_available,
+ is_torch_tf32_available,
+ is_torch_xla_available,
+ is_torch_xpu_available,
+ is_torchao_available,
+ is_torchaudio_available,
+ is_torchcodec_available,
+ is_torchdynamo_available,
+ is_torchvision_available,
+ is_triton_available,
+ is_vision_available,
+ is_vptq_available,
+ strtobool,
+)
+
+
+if is_accelerate_available():
+ from accelerate.state import AcceleratorState, PartialState
+ from accelerate.utils.imports import is_fp8_available
+
+
+if is_pytest_available():
+ from _pytest.doctest import (
+ Module,
+ _get_checker,
+ _get_continue_on_failure,
+ _get_runner,
+ _is_mocked,
+ _patch_unwrap_mock_aware,
+ get_optionflags,
+ )
+ from _pytest.outcomes import skip
+ from _pytest.pathlib import import_path
+ from pytest import DoctestItem
+else:
+ Module = object
+ DoctestItem = object
+
+
+SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
+DUMMY_UNKNOWN_IDENTIFIER = "julien-c/dummy-unknown"
+DUMMY_DIFF_TOKENIZER_IDENTIFIER = "julien-c/dummy-diff-tokenizer"
+# Used to test Auto{Config, Model, Tokenizer} model_type detection.
+
+# Used to test the hub
+USER = "__DUMMY_TRANSFORMERS_USER__"
+ENDPOINT_STAGING = "https://hub-ci.huggingface.co"
+
+# Not critical, only usable on the sandboxed CI instance.
+TOKEN = "hf_94wBhPGp6KrrTH3KDchhKpRxZwd6dmHWLL"
+
+
+# Used in CausalLMModelTester (and related classes/methods) to infer the common model classes from the base model class
+_COMMON_MODEL_NAMES_MAP = {
+ "config_class": "Config",
+ "causal_lm_class": "ForCausalLM",
+ "question_answering_class": "ForQuestionAnswering",
+ "sequence_classification_class": "ForSequenceClassification",
+ "token_classification_class": "ForTokenClassification",
+}
+
+
+if is_torch_available():
+ import torch
+
+ IS_ROCM_SYSTEM = torch.version.hip is not None
+ IS_CUDA_SYSTEM = torch.version.cuda is not None
+ IS_XPU_SYSTEM = getattr(torch.version, "xpu", None) is not None
+else:
+ IS_ROCM_SYSTEM = False
+ IS_CUDA_SYSTEM = False
+ IS_XPU_SYSTEM = False
+
+logger = transformers_logging.get_logger(__name__)
+
+
+def parse_flag_from_env(key, default=False):
+ try:
+ value = os.environ[key]
+ except KeyError:
+ # KEY isn't set, default to `default`.
+ _value = default
+ else:
+ # KEY is set, convert it to True or False.
+ try:
+ _value = strtobool(value)
+ except ValueError:
+ # More values are supported, but let's keep the message simple.
+ raise ValueError(f"If set, {key} must be yes or no.")
+ return _value
+
+
+def parse_int_from_env(key, default=None):
+ try:
+ value = os.environ[key]
+ except KeyError:
+ _value = default
+ else:
+ try:
+ _value = int(value)
+ except ValueError:
+ raise ValueError(f"If set, {key} must be a int.")
+ return _value
+
+
+_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
+_run_flaky_tests = parse_flag_from_env("RUN_FLAKY", default=True)
+_run_custom_tokenizers = parse_flag_from_env("RUN_CUSTOM_TOKENIZERS", default=False)
+_run_staging = parse_flag_from_env("HUGGINGFACE_CO_STAGING", default=False)
+_run_pipeline_tests = parse_flag_from_env("RUN_PIPELINE_TESTS", default=True)
+_run_agent_tests = parse_flag_from_env("RUN_AGENT_TESTS", default=False)
+
+
+def is_staging_test(test_case):
+ """
+ Decorator marking a test as a staging test.
+
+ Those tests will run using the staging environment of huggingface.co instead of the real model hub.
+ """
+ if not _run_staging:
+ return unittest.skip(reason="test is staging test")(test_case)
+ else:
+ try:
+ import pytest # We don't need a hard dependency on pytest in the main library
+ except ImportError:
+ return test_case
+ else:
+ return pytest.mark.is_staging_test()(test_case)
+
+
+def is_pipeline_test(test_case):
+ """
+ Decorator marking a test as a pipeline test. If RUN_PIPELINE_TESTS is set to a falsy value, those tests will be
+ skipped.
+ """
+ if not _run_pipeline_tests:
+ return unittest.skip(reason="test is pipeline test")(test_case)
+ else:
+ try:
+ import pytest # We don't need a hard dependency on pytest in the main library
+ except ImportError:
+ return test_case
+ else:
+ return pytest.mark.is_pipeline_test()(test_case)
+
+
+def is_agent_test(test_case):
+ """
+ Decorator marking a test as an agent test. If RUN_TOOL_TESTS is set to a falsy value, those tests will be skipped.
+ """
+ if not _run_agent_tests:
+ return unittest.skip(reason="test is an agent test")(test_case)
+ else:
+ try:
+ import pytest # We don't need a hard dependency on pytest in the main library
+ except ImportError:
+ return test_case
+ else:
+ return pytest.mark.is_agent_test()(test_case)
+
+
+def slow(test_case):
+ """
+ Decorator marking a test as slow.
+
+ Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them.
+
+ """
+ return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
+
+
+def tooslow(test_case):
+ """
+ Decorator marking a test as too slow.
+
+ Slow tests are skipped while they're in the process of being fixed. No test should stay tagged as "tooslow" as
+ these will not be tested by the CI.
+
+ """
+ return unittest.skip(reason="test is too slow")(test_case)
+
+
+def skip_if_not_implemented(test_func):
+ @functools.wraps(test_func)
+ def wrapper(*args, **kwargs):
+ try:
+ return test_func(*args, **kwargs)
+ except NotImplementedError as e:
+ raise unittest.SkipTest(f"Test skipped due to NotImplementedError: {e}")
+
+ return wrapper
+
+
+def apply_skip_if_not_implemented(cls):
+ """
+ Class decorator to apply @skip_if_not_implemented to all test methods.
+ """
+ for attr_name in dir(cls):
+ if attr_name.startswith("test_"):
+ attr = getattr(cls, attr_name)
+ if callable(attr):
+ setattr(cls, attr_name, skip_if_not_implemented(attr))
+ return cls
+
+
+def custom_tokenizers(test_case):
+ """
+ Decorator marking a test for a custom tokenizer.
+
+ Custom tokenizers require additional dependencies, and are skipped by default. Set the RUN_CUSTOM_TOKENIZERS
+ environment variable to a truthy value to run them.
+ """
+ return unittest.skipUnless(_run_custom_tokenizers, "test of custom tokenizers")(test_case)
+
+
+def require_bs4(test_case):
+ """
+ Decorator marking a test that requires BeautifulSoup4. These tests are skipped when BeautifulSoup4 isn't installed.
+ """
+ return unittest.skipUnless(is_bs4_available(), "test requires BeautifulSoup4")(test_case)
+
+
+def require_galore_torch(test_case):
+ """
+ Decorator marking a test that requires GaLore. These tests are skipped when GaLore isn't installed.
+ https://github.com/jiaweizzhao/GaLore
+ """
+ return unittest.skipUnless(is_galore_torch_available(), "test requires GaLore")(test_case)
+
+
+def require_apollo_torch(test_case):
+ """
+ Decorator marking a test that requires GaLore. These tests are skipped when APOLLO isn't installed.
+ https://github.com/zhuhanqing/APOLLO
+ """
+ return unittest.skipUnless(is_apollo_torch_available(), "test requires APOLLO")(test_case)
+
+
+def require_torch_optimi(test_case):
+ """
+ Decorator marking a test that requires torch-optimi. These tests are skipped when torch-optimi isn't installed.
+ https://github.com/jxnl/torch-optimi
+ """
+ return unittest.skipUnless(is_torch_optimi_available(), "test requires torch-optimi")(test_case)
+
+
+def require_lomo(test_case):
+ """
+ Decorator marking a test that requires LOMO. These tests are skipped when LOMO-optim isn't installed.
+ https://github.com/OpenLMLab/LOMO
+ """
+ return unittest.skipUnless(is_lomo_available(), "test requires LOMO")(test_case)
+
+
+def require_grokadamw(test_case):
+ """
+ Decorator marking a test that requires GrokAdamW. These tests are skipped when GrokAdamW isn't installed.
+ """
+ return unittest.skipUnless(is_grokadamw_available(), "test requires GrokAdamW")(test_case)
+
+
+def require_schedulefree(test_case):
+ """
+ Decorator marking a test that requires schedulefree. These tests are skipped when schedulefree isn't installed.
+ https://github.com/facebookresearch/schedule_free
+ """
+ return unittest.skipUnless(is_schedulefree_available(), "test requires schedulefree")(test_case)
+
+
+def require_cv2(test_case):
+ """
+ Decorator marking a test that requires OpenCV.
+
+ These tests are skipped when OpenCV isn't installed.
+
+ """
+ return unittest.skipUnless(is_cv2_available(), "test requires OpenCV")(test_case)
+
+
+def require_levenshtein(test_case):
+ """
+ Decorator marking a test that requires Levenshtein.
+
+ These tests are skipped when Levenshtein isn't installed.
+
+ """
+ return unittest.skipUnless(is_levenshtein_available(), "test requires Levenshtein")(test_case)
+
+
+def require_nltk(test_case):
+ """
+ Decorator marking a test that requires NLTK.
+
+ These tests are skipped when NLTK isn't installed.
+
+ """
+ return unittest.skipUnless(is_nltk_available(), "test requires NLTK")(test_case)
+
+
+def require_accelerate(test_case, min_version: str = ACCELERATE_MIN_VERSION):
+ """
+ Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed.
+ """
+ return unittest.skipUnless(
+ is_accelerate_available(min_version), f"test requires accelerate version >= {min_version}"
+ )(test_case)
+
+
+def require_triton(min_version: str = TRITON_MIN_VERSION):
+ """
+ Decorator marking a test that requires triton. These tests are skipped when triton isn't installed.
+ """
+
+ def decorator(test_case):
+ return unittest.skipUnless(is_triton_available(min_version), f"test requires triton version >= {min_version}")(
+ test_case
+ )
+
+ return decorator
+
+
+def require_gguf(test_case, min_version: str = GGUF_MIN_VERSION):
+ """
+ Decorator marking a test that requires ggguf. These tests are skipped when gguf isn't installed.
+ """
+ return unittest.skipUnless(is_gguf_available(min_version), f"test requires gguf version >= {min_version}")(
+ test_case
+ )
+
+
+def require_fsdp(test_case, min_version: str = "1.12.0"):
+ """
+ Decorator marking a test that requires fsdp. These tests are skipped when fsdp isn't installed.
+ """
+ return unittest.skipUnless(is_fsdp_available(min_version), f"test requires torch version >= {min_version}")(
+ test_case
+ )
+
+
+def require_g2p_en(test_case):
+ """
+ Decorator marking a test that requires g2p_en. These tests are skipped when SentencePiece isn't installed.
+ """
+ return unittest.skipUnless(is_g2p_en_available(), "test requires g2p_en")(test_case)
+
+
+def require_safetensors(test_case):
+ """
+ Decorator marking a test that requires safetensors. These tests are skipped when safetensors isn't installed.
+ """
+ return unittest.skipUnless(is_safetensors_available(), "test requires safetensors")(test_case)
+
+
+def require_rjieba(test_case):
+ """
+ Decorator marking a test that requires rjieba. These tests are skipped when rjieba isn't installed.
+ """
+ return unittest.skipUnless(is_rjieba_available(), "test requires rjieba")(test_case)
+
+
+def require_jinja(test_case):
+ """
+ Decorator marking a test that requires jinja. These tests are skipped when jinja isn't installed.
+ """
+ return unittest.skipUnless(is_jinja_available(), "test requires jinja")(test_case)
+
+
+def require_onnx(test_case):
+ return unittest.skipUnless(is_onnx_available(), "test requires ONNX")(test_case)
+
+
+def require_timm(test_case):
+ """
+ Decorator marking a test that requires Timm.
+
+ These tests are skipped when Timm isn't installed.
+
+ """
+ return unittest.skipUnless(is_timm_available(), "test requires Timm")(test_case)
+
+
+def require_natten(test_case):
+ """
+ Decorator marking a test that requires NATTEN.
+
+ These tests are skipped when NATTEN isn't installed.
+
+ """
+ return unittest.skipUnless(is_natten_available(), "test requires natten")(test_case)
+
+
+def require_torch(test_case):
+ """
+ Decorator marking a test that requires PyTorch.
+
+ These tests are skipped when PyTorch isn't installed.
+
+ """
+ return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case)
+
+
+def require_torch_greater_or_equal(version: str):
+ """
+ Decorator marking a test that requires PyTorch version >= `version`.
+
+ These tests are skipped when PyTorch version is less than `version`.
+ """
+
+ def decorator(test_case):
+ return unittest.skipUnless(is_torch_greater_or_equal(version), f"test requires PyTorch version >= {version}")(
+ test_case
+ )
+
+ return decorator
+
+
+def require_huggingface_hub_greater_or_equal(version: str):
+ """
+ Decorator marking a test that requires huggingface_hub version >= `version`.
+
+ These tests are skipped when huggingface_hub version is less than `version`.
+ """
+
+ def decorator(test_case):
+ return unittest.skipUnless(
+ is_huggingface_hub_greater_or_equal(version), f"test requires huggingface_hub version >= {version}"
+ )(test_case)
+
+ return decorator
+
+
+def require_flash_attn(test_case):
+ """
+ Decorator marking a test that requires Flash Attention.
+
+ These tests are skipped when Flash Attention isn't installed.
+
+ """
+ flash_attn_available = is_flash_attn_2_available()
+ kernels_available = is_kernels_available()
+ try:
+ from kernels import get_kernel
+
+ get_kernel("kernels-community/flash-attn")
+ except Exception as _:
+ kernels_available = False
+
+ return unittest.skipUnless(kernels_available | flash_attn_available, "test requires Flash Attention")(test_case)
+
+
+def require_kernels(test_case):
+ """
+ Decorator marking a test that requires the kernels library.
+
+ These tests are skipped when the kernels library isn't installed.
+
+ """
+ return unittest.skipUnless(is_kernels_available(), "test requires the kernels library")(test_case)
+
+
+def require_flash_attn_3(test_case):
+ """
+ Decorator marking a test that requires Flash Attention 3.
+
+ These tests are skipped when Flash Attention 3 isn't installed.
+ """
+ return unittest.skipUnless(is_flash_attn_3_available(), "test requires Flash Attention 3")(test_case)
+
+
+def require_read_token(test_case):
+ """
+ A decorator that loads the HF token for tests that require to load gated models.
+ """
+ token = os.getenv("HF_HUB_READ_TOKEN")
+
+ if isinstance(test_case, type):
+ for attr_name in dir(test_case):
+ attr = getattr(test_case, attr_name)
+ if isinstance(attr, types.FunctionType):
+ if getattr(attr, "__require_read_token__", False):
+ continue
+ wrapped = require_read_token(attr)
+ setattr(test_case, attr_name, wrapped)
+ return test_case
+ else:
+ if getattr(test_case, "__require_read_token__", False):
+ return test_case
+
+ @functools.wraps(test_case)
+ def wrapper(*args, **kwargs):
+ if token is not None:
+ with patch("huggingface_hub.utils._headers.get_token", return_value=token):
+ return test_case(*args, **kwargs)
+ else: # Allow running locally with the default token env variable
+ # dealing with static/class methods and called by `self.xxx`
+ if "staticmethod" in inspect.getsource(test_case).strip():
+ if len(args) > 0 and isinstance(args[0], unittest.TestCase):
+ return test_case(*args[1:], **kwargs)
+ return test_case(*args, **kwargs)
+
+ wrapper.__require_read_token__ = True
+ return wrapper
+
+
+def require_peft(test_case):
+ """
+ Decorator marking a test that requires PEFT.
+
+ These tests are skipped when PEFT isn't installed.
+
+ """
+ return unittest.skipUnless(is_peft_available(), "test requires PEFT")(test_case)
+
+
+def require_torchvision(test_case):
+ """
+ Decorator marking a test that requires Torchvision.
+
+ These tests are skipped when Torchvision isn't installed.
+
+ """
+ return unittest.skipUnless(is_torchvision_available(), "test requires Torchvision")(test_case)
+
+
+def require_torchcodec(test_case):
+ """
+ Decorator marking a test that requires Torchcodec.
+
+ These tests are skipped when Torchcodec isn't installed.
+
+ """
+ return unittest.skipUnless(is_torchcodec_available(), "test requires Torchcodec")(test_case)
+
+
+def require_torch_or_tf(test_case):
+ """
+ Decorator marking a test that requires PyTorch or TensorFlow.
+
+ These tests are skipped when neither PyTorch not TensorFlow is installed.
+
+ """
+ return unittest.skipUnless(is_torch_available() or is_tf_available(), "test requires PyTorch or TensorFlow")(
+ test_case
+ )
+
+
+def require_intel_extension_for_pytorch(test_case):
+ """
+ Decorator marking a test that requires Intel Extension for PyTorch.
+
+ These tests are skipped when Intel Extension for PyTorch isn't installed or it does not match current PyTorch
+ version.
+
+ """
+ return unittest.skipUnless(
+ is_ipex_available(),
+ "test requires Intel Extension for PyTorch to be installed and match current PyTorch version, see"
+ " https://github.com/intel/intel-extension-for-pytorch",
+ )(test_case)
+
+
+def require_torchaudio(test_case):
+ """
+ Decorator marking a test that requires torchaudio. These tests are skipped when torchaudio isn't installed.
+ """
+ return unittest.skipUnless(is_torchaudio_available(), "test requires torchaudio")(test_case)
+
+
+def require_sentencepiece(test_case):
+ """
+ Decorator marking a test that requires SentencePiece. These tests are skipped when SentencePiece isn't installed.
+ """
+ return unittest.skipUnless(is_sentencepiece_available(), "test requires SentencePiece")(test_case)
+
+
+def require_sacremoses(test_case):
+ """
+ Decorator marking a test that requires Sacremoses. These tests are skipped when Sacremoses isn't installed.
+ """
+ return unittest.skipUnless(is_sacremoses_available(), "test requires Sacremoses")(test_case)
+
+
+def require_seqio(test_case):
+ """
+ Decorator marking a test that requires SentencePiece. These tests are skipped when SentencePiece isn't installed.
+ """
+ return unittest.skipUnless(is_seqio_available(), "test requires Seqio")(test_case)
+
+
+def require_scipy(test_case):
+ """
+ Decorator marking a test that requires Scipy. These tests are skipped when SentencePiece isn't installed.
+ """
+ return unittest.skipUnless(is_scipy_available(), "test requires Scipy")(test_case)
+
+
+def require_tokenizers(test_case):
+ """
+ Decorator marking a test that requires 🤗 Tokenizers. These tests are skipped when 🤗 Tokenizers isn't installed.
+ """
+ return unittest.skipUnless(is_tokenizers_available(), "test requires tokenizers")(test_case)
+
+
+def require_keras_nlp(test_case):
+ """
+ Decorator marking a test that requires keras_nlp. These tests are skipped when keras_nlp isn't installed.
+ """
+ return unittest.skipUnless(is_keras_nlp_available(), "test requires keras_nlp")(test_case)
+
+
+def require_pandas(test_case):
+ """
+ Decorator marking a test that requires pandas. These tests are skipped when pandas isn't installed.
+ """
+ return unittest.skipUnless(is_pandas_available(), "test requires pandas")(test_case)
+
+
+def require_pytesseract(test_case):
+ """
+ Decorator marking a test that requires PyTesseract. These tests are skipped when PyTesseract isn't installed.
+ """
+ return unittest.skipUnless(is_pytesseract_available(), "test requires PyTesseract")(test_case)
+
+
+def require_pytorch_quantization(test_case):
+ """
+ Decorator marking a test that requires PyTorch Quantization Toolkit. These tests are skipped when PyTorch
+ Quantization Toolkit isn't installed.
+ """
+ return unittest.skipUnless(is_pytorch_quantization_available(), "test requires PyTorch Quantization Toolkit")(
+ test_case
+ )
+
+
+def require_vision(test_case):
+ """
+ Decorator marking a test that requires the vision dependencies. These tests are skipped when torchaudio isn't
+ installed.
+ """
+ return unittest.skipUnless(is_vision_available(), "test requires vision")(test_case)
+
+
+def require_ftfy(test_case):
+ """
+ Decorator marking a test that requires ftfy. These tests are skipped when ftfy isn't installed.
+ """
+ return unittest.skipUnless(is_ftfy_available(), "test requires ftfy")(test_case)
+
+
+def require_spacy(test_case):
+ """
+ Decorator marking a test that requires SpaCy. These tests are skipped when SpaCy isn't installed.
+ """
+ return unittest.skipUnless(is_spacy_available(), "test requires spacy")(test_case)
+
+
+def require_torch_multi_gpu(test_case):
+ """
+ Decorator marking a test that requires a multi-GPU CUDA setup (in PyTorch). These tests are skipped on a machine without
+ multiple CUDA GPUs.
+
+ To run *only* the multi_gpu tests, assuming all test names contain multi_gpu: $ pytest -sv ./tests -k "multi_gpu"
+ """
+ if not is_torch_available():
+ return unittest.skip(reason="test requires PyTorch")(test_case)
+
+ import torch
+
+ return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple CUDA GPUs")(test_case)
+
+
+def require_torch_multi_accelerator(test_case):
+ """
+ Decorator marking a test that requires a multi-accelerator (in PyTorch). These tests are skipped on a machine
+ without multiple accelerators. To run *only* the multi_accelerator tests, assuming all test names contain
+ multi_accelerator: $ pytest -sv ./tests -k "multi_accelerator"
+ """
+ if not is_torch_available():
+ return unittest.skip(reason="test requires PyTorch")(test_case)
+
+ return unittest.skipUnless(backend_device_count(torch_device) > 1, "test requires multiple accelerators")(
+ test_case
+ )
+
+
+def require_torch_non_multi_gpu(test_case):
+ """
+ Decorator marking a test that requires 0 or 1 GPU setup (in PyTorch).
+ """
+ if not is_torch_available():
+ return unittest.skip(reason="test requires PyTorch")(test_case)
+
+ import torch
+
+ return unittest.skipUnless(torch.cuda.device_count() < 2, "test requires 0 or 1 GPU")(test_case)
+
+
+def require_torch_non_multi_accelerator(test_case):
+ """
+ Decorator marking a test that requires 0 or 1 accelerator setup (in PyTorch).
+ """
+ if not is_torch_available():
+ return unittest.skip(reason="test requires PyTorch")(test_case)
+
+ return unittest.skipUnless(backend_device_count(torch_device) < 2, "test requires 0 or 1 accelerator")(test_case)
+
+
+def require_torch_up_to_2_gpus(test_case):
+ """
+ Decorator marking a test that requires 0 or 1 or 2 GPU setup (in PyTorch).
+ """
+ if not is_torch_available():
+ return unittest.skip(reason="test requires PyTorch")(test_case)
+
+ import torch
+
+ return unittest.skipUnless(torch.cuda.device_count() < 3, "test requires 0 or 1 or 2 GPUs")(test_case)
+
+
+def require_torch_up_to_2_accelerators(test_case):
+ """
+ Decorator marking a test that requires 0 or 1 or 2 accelerator setup (in PyTorch).
+ """
+ if not is_torch_available():
+ return unittest.skip(reason="test requires PyTorch")(test_case)
+
+ return unittest.skipUnless(backend_device_count(torch_device) < 3, "test requires 0 or 1 or 2 accelerators")(
+ test_case
+ )
+
+
+def require_torch_xla(test_case):
+ """
+ Decorator marking a test that requires TorchXLA (in PyTorch).
+ """
+ return unittest.skipUnless(is_torch_xla_available(), "test requires TorchXLA")(test_case)
+
+
+def require_torch_neuroncore(test_case):
+ """
+ Decorator marking a test that requires NeuronCore (in PyTorch).
+ """
+ return unittest.skipUnless(is_torch_neuroncore_available(check_device=False), "test requires PyTorch NeuronCore")(
+ test_case
+ )
+
+
+def require_torch_npu(test_case):
+ """
+ Decorator marking a test that requires NPU (in PyTorch).
+ """
+ return unittest.skipUnless(is_torch_npu_available(), "test requires PyTorch NPU")(test_case)
+
+
+def require_torch_multi_npu(test_case):
+ """
+ Decorator marking a test that requires a multi-NPU setup (in PyTorch). These tests are skipped on a machine without
+ multiple NPUs.
+
+ To run *only* the multi_npu tests, assuming all test names contain multi_npu: $ pytest -sv ./tests -k "multi_npu"
+ """
+ if not is_torch_npu_available():
+ return unittest.skip(reason="test requires PyTorch NPU")(test_case)
+
+ return unittest.skipUnless(torch.npu.device_count() > 1, "test requires multiple NPUs")(test_case)
+
+
+def require_non_hpu(test_case):
+ """
+ Decorator marking a test that should be skipped for HPU.
+ """
+ return unittest.skipUnless(torch_device != "hpu", "test requires a non-HPU")(test_case)
+
+
+def require_torch_xpu(test_case):
+ """
+ Decorator marking a test that requires XPU (in PyTorch).
+
+ These tests are skipped when XPU backend is not available. XPU backend might be available either via stock
+ PyTorch (>=2.4) or via Intel Extension for PyTorch. In the latter case, if IPEX is installed, its version
+ must match match current PyTorch version.
+ """
+ return unittest.skipUnless(is_torch_xpu_available(), "test requires XPU device")(test_case)
+
+
+def require_non_xpu(test_case):
+ """
+ Decorator marking a test that should be skipped for XPU.
+ """
+ return unittest.skipUnless(torch_device != "xpu", "test requires a non-XPU")(test_case)
+
+
+def require_torch_multi_xpu(test_case):
+ """
+ Decorator marking a test that requires a multi-XPU setup (in PyTorch). These tests are skipped on a machine without
+ multiple XPUs.
+
+ To run *only* the multi_xpu tests, assuming all test names contain multi_xpu: $ pytest -sv ./tests -k "multi_xpu"
+ """
+ if not is_torch_xpu_available():
+ return unittest.skip(reason="test requires PyTorch XPU")(test_case)
+
+ return unittest.skipUnless(torch.xpu.device_count() > 1, "test requires multiple XPUs")(test_case)
+
+
+def require_torch_multi_hpu(test_case):
+ """
+ Decorator marking a test that requires a multi-HPU setup (in PyTorch). These tests are skipped on a machine without
+ multiple HPUs.
+
+ To run *only* the multi_hpu tests, assuming all test names contain multi_hpu: $ pytest -sv ./tests -k "multi_hpu"
+ """
+ if not is_torch_hpu_available():
+ return unittest.skip(reason="test requires PyTorch HPU")(test_case)
+
+ return unittest.skipUnless(torch.hpu.device_count() > 1, "test requires multiple HPUs")(test_case)
+
+
+if is_torch_available():
+ # Set env var CUDA_VISIBLE_DEVICES="" to force cpu-mode
+ import torch
+
+ if "TRANSFORMERS_TEST_BACKEND" in os.environ:
+ backend = os.environ["TRANSFORMERS_TEST_BACKEND"]
+ try:
+ _ = importlib.import_module(backend)
+ except ModuleNotFoundError as e:
+ raise ModuleNotFoundError(
+ f"Failed to import `TRANSFORMERS_TEST_BACKEND` '{backend}'! This should be the name of an installed module. The original error (look up to see its"
+ f" traceback):\n{e}"
+ ) from e
+
+ if "TRANSFORMERS_TEST_DEVICE" in os.environ:
+ torch_device = os.environ["TRANSFORMERS_TEST_DEVICE"]
+ if torch_device == "cuda" and not torch.cuda.is_available():
+ raise ValueError(
+ f"TRANSFORMERS_TEST_DEVICE={torch_device}, but CUDA is unavailable. Please double-check your testing environment."
+ )
+ if torch_device == "xpu" and not is_torch_xpu_available():
+ raise ValueError(
+ f"TRANSFORMERS_TEST_DEVICE={torch_device}, but XPU is unavailable. Please double-check your testing environment."
+ )
+ if torch_device == "npu" and not is_torch_npu_available():
+ raise ValueError(
+ f"TRANSFORMERS_TEST_DEVICE={torch_device}, but NPU is unavailable. Please double-check your testing environment."
+ )
+ if torch_device == "mlu" and not is_torch_mlu_available():
+ raise ValueError(
+ f"TRANSFORMERS_TEST_DEVICE={torch_device}, but MLU is unavailable. Please double-check your testing environment."
+ )
+ if torch_device == "hpu" and not is_torch_hpu_available():
+ raise ValueError(
+ f"TRANSFORMERS_TEST_DEVICE={torch_device}, but HPU is unavailable. Please double-check your testing environment."
+ )
+
+ try:
+ # try creating device to see if provided device is valid
+ _ = torch.device(torch_device)
+ except RuntimeError as e:
+ raise RuntimeError(
+ f"Unknown testing device specified by environment variable `TRANSFORMERS_TEST_DEVICE`: {torch_device}"
+ ) from e
+ elif torch.cuda.is_available():
+ torch_device = "cuda"
+ elif is_torch_npu_available():
+ torch_device = "npu"
+ elif is_torch_mlu_available():
+ torch_device = "mlu"
+ elif is_torch_hpu_available():
+ torch_device = "hpu"
+ elif is_torch_xpu_available():
+ torch_device = "xpu"
+ else:
+ torch_device = "cpu"
+else:
+ torch_device = None
+
+if is_tf_available():
+ import tensorflow as tf
+
+if is_flax_available():
+ import jax
+
+ jax_device = jax.default_backend()
+else:
+ jax_device = None
+
+
+def require_torchdynamo(test_case):
+ """Decorator marking a test that requires TorchDynamo"""
+ return unittest.skipUnless(is_torchdynamo_available(), "test requires TorchDynamo")(test_case)
+
+
+def require_torchao(test_case):
+ """Decorator marking a test that requires torchao"""
+ return unittest.skipUnless(is_torchao_available(), "test requires torchao")(test_case)
+
+
+def require_torchao_version_greater_or_equal(torchao_version):
+ def decorator(test_case):
+ correct_torchao_version = is_torchao_available() and version.parse(
+ version.parse(importlib.metadata.version("torchao")).base_version
+ ) >= version.parse(torchao_version)
+ return unittest.skipUnless(
+ correct_torchao_version, f"Test requires torchao with the version greater than {torchao_version}."
+ )(test_case)
+
+ return decorator
+
+
+def require_torch_tensorrt_fx(test_case):
+ """Decorator marking a test that requires Torch-TensorRT FX"""
+ return unittest.skipUnless(is_torch_tensorrt_fx_available(), "test requires Torch-TensorRT FX")(test_case)
+
+
+def require_torch_gpu(test_case):
+ """Decorator marking a test that requires CUDA and PyTorch."""
+ return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case)
+
+
+def require_torch_mps(test_case):
+ """Decorator marking a test that requires CUDA and PyTorch."""
+ return unittest.skipUnless(torch_device == "mps", "test requires MPS")(test_case)
+
+
+def require_large_cpu_ram(test_case, memory: float = 80):
+ """Decorator marking a test that requires a CPU RAM with more than `memory` GiB of memory."""
+ if not is_psutil_available():
+ return test_case
+
+ import psutil
+
+ return unittest.skipUnless(
+ psutil.virtual_memory().total / 1024**3 > memory,
+ f"test requires a machine with more than {memory} GiB of CPU RAM memory",
+ )(test_case)
+
+
+def require_torch_large_gpu(test_case, memory: float = 20):
+ """Decorator marking a test that requires a CUDA GPU with more than `memory` GiB of memory."""
+ if torch_device != "cuda":
+ return unittest.skip(reason=f"test requires a CUDA GPU with more than {memory} GiB of memory")(test_case)
+
+ return unittest.skipUnless(
+ torch.cuda.get_device_properties(0).total_memory / 1024**3 > memory,
+ f"test requires a GPU with more than {memory} GiB of memory",
+ )(test_case)
+
+
+def require_torch_large_accelerator(test_case, memory: float = 20):
+ """Decorator marking a test that requires an accelerator with more than `memory` GiB of memory."""
+ if torch_device != "cuda" and torch_device != "xpu":
+ return unittest.skip(reason=f"test requires a GPU or XPU with more than {memory} GiB of memory")(test_case)
+
+ torch_accelerator_module = getattr(torch, torch_device)
+
+ return unittest.skipUnless(
+ torch_accelerator_module.get_device_properties(0).total_memory / 1024**3 > memory,
+ f"test requires a GPU or XPU with more than {memory} GiB of memory",
+ )(test_case)
+
+
+def require_torch_gpu_if_bnb_not_multi_backend_enabled(test_case):
+ """
+ Decorator marking a test that requires a GPU if bitsandbytes multi-backend feature is not enabled.
+ """
+ if is_bitsandbytes_available() and is_bitsandbytes_multi_backend_available():
+ return test_case
+ return require_torch_gpu(test_case)
+
+
+def require_torch_accelerator(test_case):
+ """Decorator marking a test that requires an accessible accelerator and PyTorch."""
+ return unittest.skipUnless(torch_device is not None and torch_device != "cpu", "test requires accelerator")(
+ test_case
+ )
+
+
+def require_torch_fp16(test_case):
+ """Decorator marking a test that requires a device that supports fp16"""
+ return unittest.skipUnless(
+ is_torch_fp16_available_on_device(torch_device), "test requires device with fp16 support"
+ )(test_case)
+
+
+def require_fp8(test_case):
+ """Decorator marking a test that requires supports for fp8"""
+ return unittest.skipUnless(is_accelerate_available() and is_fp8_available(), "test requires fp8 support")(
+ test_case
+ )
+
+
+def require_torch_bf16(test_case):
+ """Decorator marking a test that requires a device that supports bf16"""
+ return unittest.skipUnless(
+ is_torch_bf16_available_on_device(torch_device), "test requires device with bf16 support"
+ )(test_case)
+
+
+def require_torch_bf16_gpu(test_case):
+ """Decorator marking a test that requires torch>=1.10, using Ampere GPU or newer arch with cuda>=11.0"""
+ return unittest.skipUnless(
+ is_torch_bf16_gpu_available(),
+ "test requires torch>=1.10, using Ampere GPU or newer arch with cuda>=11.0",
+ )(test_case)
+
+
+def require_deterministic_for_xpu(test_case):
+ @wraps(test_case)
+ def wrapper(*args, **kwargs):
+ if is_torch_xpu_available():
+ original_state = torch.are_deterministic_algorithms_enabled()
+ try:
+ torch.use_deterministic_algorithms(True)
+ return test_case(*args, **kwargs)
+ finally:
+ torch.use_deterministic_algorithms(original_state)
+ else:
+ return test_case(*args, **kwargs)
+
+ return wrapper
+
+
+def require_torch_tf32(test_case):
+ """Decorator marking a test that requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7."""
+ return unittest.skipUnless(
+ is_torch_tf32_available(), "test requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7"
+ )(test_case)
+
+
+def require_detectron2(test_case):
+ """Decorator marking a test that requires detectron2."""
+ return unittest.skipUnless(is_detectron2_available(), "test requires `detectron2`")(test_case)
+
+
+def require_faiss(test_case):
+ """Decorator marking a test that requires faiss."""
+ return unittest.skipUnless(is_faiss_available(), "test requires `faiss`")(test_case)
+
+
+def require_optuna(test_case):
+ """
+ Decorator marking a test that requires optuna.
+
+ These tests are skipped when optuna isn't installed.
+
+ """
+ return unittest.skipUnless(is_optuna_available(), "test requires optuna")(test_case)
+
+
+def require_ray(test_case):
+ """
+ Decorator marking a test that requires Ray/tune.
+
+ These tests are skipped when Ray/tune isn't installed.
+
+ """
+ return unittest.skipUnless(is_ray_available(), "test requires Ray/tune")(test_case)
+
+
+def require_sigopt(test_case):
+ """
+ Decorator marking a test that requires SigOpt.
+
+ These tests are skipped when SigOpt isn't installed.
+
+ """
+ return unittest.skipUnless(is_sigopt_available(), "test requires SigOpt")(test_case)
+
+
+def require_swanlab(test_case):
+ """
+ Decorator marking a test that requires swanlab.
+
+ These tests are skipped when swanlab isn't installed.
+
+ """
+ return unittest.skipUnless(is_swanlab_available(), "test requires swanlab")(test_case)
+
+
+def require_trackio(test_case):
+ """
+ Decorator marking a test that requires trackio.
+
+ These tests are skipped when trackio isn't installed.
+
+ """
+ return unittest.skipUnless(is_trackio_available(), "test requires trackio")(test_case)
+
+
+def require_wandb(test_case):
+ """
+ Decorator marking a test that requires wandb.
+
+ These tests are skipped when wandb isn't installed.
+
+ """
+ return unittest.skipUnless(is_wandb_available(), "test requires wandb")(test_case)
+
+
+def require_clearml(test_case):
+ """
+ Decorator marking a test requires clearml.
+
+ These tests are skipped when clearml isn't installed.
+
+ """
+ return unittest.skipUnless(is_clearml_available(), "test requires clearml")(test_case)
+
+
+def require_deepspeed(test_case):
+ """
+ Decorator marking a test that requires deepspeed
+ """
+ return unittest.skipUnless(is_deepspeed_available(), "test requires deepspeed")(test_case)
+
+
+def require_apex(test_case):
+ """
+ Decorator marking a test that requires apex
+ """
+ return unittest.skipUnless(is_apex_available(), "test requires apex")(test_case)
+
+
+def require_aqlm(test_case):
+ """
+ Decorator marking a test that requires aqlm
+ """
+ return unittest.skipUnless(is_aqlm_available(), "test requires aqlm")(test_case)
+
+
+def require_vptq(test_case):
+ """
+ Decorator marking a test that requires vptq
+ """
+ return unittest.skipUnless(is_vptq_available(), "test requires vptq")(test_case)
+
+
+def require_spqr(test_case):
+ """
+ Decorator marking a test that requires spqr
+ """
+ return unittest.skipUnless(is_spqr_available(), "test requires spqr")(test_case)
+
+
+def require_eetq(test_case):
+ """
+ Decorator marking a test that requires eetq
+ """
+ eetq_available = is_eetq_available()
+ if eetq_available:
+ try:
+ import eetq # noqa: F401
+ except ImportError as exc:
+ if "shard_checkpoint" in str(exc):
+ # EETQ 1.0.0 is currently broken with the latest transformers because it tries to import the removed
+ # shard_checkpoint function, see https://github.com/NetEase-FuXi/EETQ/issues/34.
+ # TODO: Remove once eetq releases a fix and this release is used in CI
+ eetq_available = False
+ return unittest.skipUnless(eetq_available, "test requires eetq")(test_case)
+
+
+def require_av(test_case):
+ """
+ Decorator marking a test that requires av
+ """
+ return unittest.skipUnless(is_av_available(), "test requires av")(test_case)
+
+
+def require_decord(test_case):
+ """
+ Decorator marking a test that requires decord
+ """
+ return unittest.skipUnless(is_decord_available(), "test requires decord")(test_case)
+
+
+def require_bitsandbytes(test_case):
+ """
+ Decorator marking a test that requires the bitsandbytes library. Will be skipped when the library or its hard dependency torch is not installed.
+ """
+ if is_bitsandbytes_available() and is_torch_available():
+ try:
+ import pytest
+
+ return pytest.mark.bitsandbytes(test_case)
+ except ImportError:
+ return test_case
+ else:
+ return unittest.skip(reason="test requires bitsandbytes and torch")(test_case)
+
+
+def require_optimum(test_case):
+ """
+ Decorator for optimum dependency
+ """
+ return unittest.skipUnless(is_optimum_available(), "test requires optimum")(test_case)
+
+
+def require_tensorboard(test_case):
+ """
+ Decorator for `tensorboard` dependency
+ """
+ return unittest.skipUnless(is_tensorboard_available(), "test requires tensorboard")
+
+
+def require_gptq(test_case):
+ """
+ Decorator for auto_gptq dependency
+ """
+ return unittest.skipUnless(
+ is_gptqmodel_available() or is_auto_gptq_available(), "test requires gptqmodel or auto-gptq"
+ )(test_case)
+
+
+def require_hqq(test_case):
+ """
+ Decorator for hqq dependency
+ """
+ return unittest.skipUnless(is_hqq_available(), "test requires hqq")(test_case)
+
+
+def require_auto_awq(test_case):
+ """
+ Decorator for auto_awq dependency
+ """
+ return unittest.skipUnless(is_auto_awq_available(), "test requires autoawq")(test_case)
+
+
+def require_auto_round(test_case):
+ """
+ Decorator for auto_round dependency
+ """
+ return unittest.skipUnless(is_auto_round_available(), "test requires autoround")(test_case)
+
+
+def require_optimum_quanto(test_case):
+ """
+ Decorator for quanto dependency
+ """
+ return unittest.skipUnless(is_optimum_quanto_available(), "test requires optimum-quanto")(test_case)
+
+
+def require_compressed_tensors(test_case):
+ """
+ Decorator for compressed_tensors dependency
+ """
+ return unittest.skipUnless(is_compressed_tensors_available(), "test requires compressed_tensors")(test_case)
+
+
+def require_fbgemm_gpu(test_case):
+ """
+ Decorator for fbgemm_gpu dependency
+ """
+ return unittest.skipUnless(is_fbgemm_gpu_available(), "test requires fbgemm-gpu")(test_case)
+
+
+def require_quark(test_case):
+ """
+ Decorator for quark dependency
+ """
+ return unittest.skipUnless(is_quark_available(), "test requires quark")(test_case)
+
+
+def require_flute_hadamard(test_case):
+ """
+ Decorator marking a test that requires higgs and hadamard
+ """
+ return unittest.skipUnless(
+ is_flute_available() and is_hadamard_available(), "test requires flute and fast_hadamard_transform"
+ )(test_case)
+
+
+def require_fp_quant(test_case):
+ """
+ Decorator marking a test that requires fp_quant and qutlass
+ """
+ return unittest.skipUnless(is_fp_quant_available(), "test requires fp_quant")(test_case)
+
+
+def require_qutlass(test_case):
+ """
+ Decorator marking a test that requires qutlass
+ """
+ return unittest.skipUnless(is_qutlass_available(), "test requires qutlass")(test_case)
+
+
+def require_phonemizer(test_case):
+ """
+ Decorator marking a test that requires phonemizer
+ """
+ return unittest.skipUnless(is_phonemizer_available(), "test requires phonemizer")(test_case)
+
+
+def require_pyctcdecode(test_case):
+ """
+ Decorator marking a test that requires pyctcdecode
+ """
+ return unittest.skipUnless(is_pyctcdecode_available(), "test requires pyctcdecode")(test_case)
+
+
+def require_librosa(test_case):
+ """
+ Decorator marking a test that requires librosa
+ """
+ return unittest.skipUnless(is_librosa_available(), "test requires librosa")(test_case)
+
+
+def require_liger_kernel(test_case):
+ """
+ Decorator marking a test that requires liger_kernel
+ """
+ return unittest.skipUnless(is_liger_kernel_available(), "test requires liger_kernel")(test_case)
+
+
+def require_essentia(test_case):
+ """
+ Decorator marking a test that requires essentia
+ """
+ return unittest.skipUnless(is_essentia_available(), "test requires essentia")(test_case)
+
+
+def require_pretty_midi(test_case):
+ """
+ Decorator marking a test that requires pretty_midi
+ """
+ return unittest.skipUnless(is_pretty_midi_available(), "test requires pretty_midi")(test_case)
+
+
+def cmd_exists(cmd):
+ return shutil.which(cmd) is not None
+
+
+def require_usr_bin_time(test_case):
+ """
+ Decorator marking a test that requires `/usr/bin/time`
+ """
+ return unittest.skipUnless(cmd_exists("/usr/bin/time"), "test requires /usr/bin/time")(test_case)
+
+
+def require_sudachi(test_case):
+ """
+ Decorator marking a test that requires sudachi
+ """
+ return unittest.skipUnless(is_sudachi_available(), "test requires sudachi")(test_case)
+
+
+def require_sudachi_projection(test_case):
+ """
+ Decorator marking a test that requires sudachi_projection
+ """
+ return unittest.skipUnless(is_sudachi_projection_available(), "test requires sudachi which supports projection")(
+ test_case
+ )
+
+
+def require_jumanpp(test_case):
+ """
+ Decorator marking a test that requires jumanpp
+ """
+ return unittest.skipUnless(is_jumanpp_available(), "test requires jumanpp")(test_case)
+
+
+def require_cython(test_case):
+ """
+ Decorator marking a test that requires jumanpp
+ """
+ return unittest.skipUnless(is_cython_available(), "test requires cython")(test_case)
+
+
+def require_tiktoken(test_case):
+ """
+ Decorator marking a test that requires TikToken. These tests are skipped when TikToken isn't installed.
+ """
+ return unittest.skipUnless(is_tiktoken_available(), "test requires TikToken")(test_case)
+
+
+def require_speech(test_case):
+ """
+ Decorator marking a test that requires speech. These tests are skipped when speech isn't available.
+ """
+ return unittest.skipUnless(is_speech_available(), "test requires torchaudio")(test_case)
+
+
+def require_openai(test_case):
+ """
+ Decorator marking a test that requires openai
+ """
+ return unittest.skipUnless(is_openai_available(), "test requires openai")(test_case)
+
+
+def require_mistral_common(test_case):
+ """
+ Decorator marking a test that requires mistral-common. These tests are skipped when mistral-common isn't available.
+ """
+ return unittest.skipUnless(is_mistral_common_available(), "test requires mistral-common")(test_case)
+
+
+def get_gpu_count():
+ """
+ Return the number of available gpus (regardless of whether torch, tf or jax is used)
+ """
+ if is_torch_available():
+ import torch
+
+ return torch.cuda.device_count()
+ elif is_tf_available():
+ import tensorflow as tf
+
+ return len(tf.config.list_physical_devices("GPU"))
+ elif is_flax_available():
+ import jax
+
+ return jax.device_count()
+ else:
+ return 0
+
+
+def get_tests_dir(append_path=None):
+ """
+ Args:
+ append_path: optional path to append to the tests dir path
+
+ Return:
+ The full path to the `tests` dir, so that the tests can be invoked from anywhere. Optionally `append_path` is
+ joined after the `tests` dir the former is provided.
+
+ """
+ # this function caller's __file__
+ caller__file__ = inspect.stack()[1][1]
+ tests_dir = os.path.abspath(os.path.dirname(caller__file__))
+
+ while not tests_dir.endswith("tests"):
+ tests_dir = os.path.dirname(tests_dir)
+
+ if append_path:
+ return os.path.join(tests_dir, append_path)
+ else:
+ return tests_dir
+
+
+def get_steps_per_epoch(trainer: Trainer) -> int:
+ training_args = trainer.args
+ train_dataloader = trainer.get_train_dataloader()
+
+ initial_training_values = trainer.set_initial_training_values(
+ args=training_args,
+ dataloader=train_dataloader,
+ total_train_batch_size=training_args.per_device_train_batch_size,
+ )
+ steps_per_epoch = initial_training_values[1]
+
+ return steps_per_epoch
+
+
+def evaluate_side_effect_factory(
+ side_effect_values: list[dict[str, float]],
+) -> Generator[dict[str, float], None, None]:
+ """
+ Function that returns side effects for the _evaluate method.
+ Used when we're unsure of exactly how many times _evaluate will be called.
+ """
+ yield from side_effect_values
+
+ while True:
+ yield side_effect_values[-1]
+
+
+#
+# Helper functions for dealing with testing text outputs
+# The original code came from:
+# https://github.com/fastai/fastai/blob/master/tests/utils/text.py
+
+
+# When any function contains print() calls that get overwritten, like progress bars,
+# a special care needs to be applied, since under pytest -s captured output (capsys
+# or contextlib.redirect_stdout) contains any temporary printed strings, followed by
+# \r's. This helper function ensures that the buffer will contain the same output
+# with and without -s in pytest, by turning:
+# foo bar\r tar mar\r final message
+# into:
+# final message
+# it can handle a single string or a multiline buffer
+def apply_print_resets(buf):
+ return re.sub(r"^.*\r", "", buf, 0, re.MULTILINE)
+
+
+def assert_screenout(out, what):
+ out_pr = apply_print_resets(out).lower()
+ match_str = out_pr.find(what.lower())
+ assert match_str != -1, f"expecting to find {what} in output: f{out_pr}"
+
+
+def set_config_for_less_flaky_test(config):
+ target_attrs = [
+ "rms_norm_eps",
+ "layer_norm_eps",
+ "norm_eps",
+ "norm_epsilon",
+ "layer_norm_epsilon",
+ "batch_norm_eps",
+ ]
+ for target_attr in target_attrs:
+ setattr(config, target_attr, 1.0)
+
+ # norm layers (layer/group norm, etc.) could cause flaky tests when the tensors have very small variance.
+ # (We don't need the original epsilon values to check eager/sdpa matches)
+ attrs = ["text_config", "vision_config", "text_encoder", "audio_encoder", "decoder"]
+ for attr in attrs:
+ if hasattr(config, attr):
+ for target_attr in target_attrs:
+ setattr(getattr(config, attr), target_attr, 1.0)
+
+
+def set_model_for_less_flaky_test(model):
+ # Another way to make sure norm layers have desired epsilon. (Some models don't set it from its config.)
+ target_names = (
+ "LayerNorm",
+ "GroupNorm",
+ "BatchNorm",
+ "RMSNorm",
+ "BatchNorm2d",
+ "BatchNorm1d",
+ "BitGroupNormActivation",
+ "WeightStandardizedConv2d",
+ )
+ target_attrs = ["eps", "epsilon", "variance_epsilon"]
+ if is_torch_available() and isinstance(model, torch.nn.Module):
+ for module in model.modules():
+ if type(module).__name__.endswith(target_names):
+ for attr in target_attrs:
+ if hasattr(module, attr):
+ setattr(module, attr, 1.0)
+
+
+class CaptureStd:
+ """
+ Context manager to capture:
+
+ - stdout: replay it, clean it up and make it available via `obj.out`
+ - stderr: replay it and make it available via `obj.err`
+
+ Args:
+ out (`bool`, *optional*, defaults to `True`): Whether to capture stdout or not.
+ err (`bool`, *optional*, defaults to `True`): Whether to capture stderr or not.
+ replay (`bool`, *optional*, defaults to `True`): Whether to replay or not.
+ By default each captured stream gets replayed back on context's exit, so that one can see what the test was
+ doing. If this is a not wanted behavior and the captured data shouldn't be replayed, pass `replay=False` to
+ disable this feature.
+
+ Examples:
+
+ ```python
+ # to capture stdout only with auto-replay
+ with CaptureStdout() as cs:
+ print("Secret message")
+ assert "message" in cs.out
+
+ # to capture stderr only with auto-replay
+ import sys
+
+ with CaptureStderr() as cs:
+ print("Warning: ", file=sys.stderr)
+ assert "Warning" in cs.err
+
+ # to capture both streams with auto-replay
+ with CaptureStd() as cs:
+ print("Secret message")
+ print("Warning: ", file=sys.stderr)
+ assert "message" in cs.out
+ assert "Warning" in cs.err
+
+ # to capture just one of the streams, and not the other, with auto-replay
+ with CaptureStd(err=False) as cs:
+ print("Secret message")
+ assert "message" in cs.out
+ # but best use the stream-specific subclasses
+
+ # to capture without auto-replay
+ with CaptureStd(replay=False) as cs:
+ print("Secret message")
+ assert "message" in cs.out
+ ```"""
+
+ def __init__(self, out=True, err=True, replay=True):
+ self.replay = replay
+
+ if out:
+ self.out_buf = StringIO()
+ self.out = "error: CaptureStd context is unfinished yet, called too early"
+ else:
+ self.out_buf = None
+ self.out = "not capturing stdout"
+
+ if err:
+ self.err_buf = StringIO()
+ self.err = "error: CaptureStd context is unfinished yet, called too early"
+ else:
+ self.err_buf = None
+ self.err = "not capturing stderr"
+
+ def __enter__(self):
+ if self.out_buf:
+ self.out_old = sys.stdout
+ sys.stdout = self.out_buf
+
+ if self.err_buf:
+ self.err_old = sys.stderr
+ sys.stderr = self.err_buf
+
+ return self
+
+ def __exit__(self, *exc):
+ if self.out_buf:
+ sys.stdout = self.out_old
+ captured = self.out_buf.getvalue()
+ if self.replay:
+ sys.stdout.write(captured)
+ self.out = apply_print_resets(captured)
+
+ if self.err_buf:
+ sys.stderr = self.err_old
+ captured = self.err_buf.getvalue()
+ if self.replay:
+ sys.stderr.write(captured)
+ self.err = captured
+
+ def __repr__(self):
+ msg = ""
+ if self.out_buf:
+ msg += f"stdout: {self.out}\n"
+ if self.err_buf:
+ msg += f"stderr: {self.err}\n"
+ return msg
+
+
+# in tests it's the best to capture only the stream that's wanted, otherwise
+# it's easy to miss things, so unless you need to capture both streams, use the
+# subclasses below (less typing). Or alternatively, configure `CaptureStd` to
+# disable the stream you don't need to test.
+
+
+class CaptureStdout(CaptureStd):
+ """Same as CaptureStd but captures only stdout"""
+
+ def __init__(self, replay=True):
+ super().__init__(err=False, replay=replay)
+
+
+class CaptureStderr(CaptureStd):
+ """Same as CaptureStd but captures only stderr"""
+
+ def __init__(self, replay=True):
+ super().__init__(out=False, replay=replay)
+
+
+class CaptureLogger:
+ """
+ Context manager to capture `logging` streams
+
+ Args:
+ logger: 'logging` logger object
+
+ Returns:
+ The captured output is available via `self.out`
+
+ Example:
+
+ ```python
+ >>> from transformers import logging
+ >>> from transformers.testing_utils import CaptureLogger
+
+ >>> msg = "Testing 1, 2, 3"
+ >>> logging.set_verbosity_info()
+ >>> logger = logging.get_logger("transformers.models.bart.tokenization_bart")
+ >>> with CaptureLogger(logger) as cl:
+ ... logger.info(msg)
+ >>> assert cl.out, msg + "\n"
+ ```
+ """
+
+ def __init__(self, logger):
+ self.logger = logger
+ self.io = StringIO()
+ self.sh = logging.StreamHandler(self.io)
+ self.out = ""
+
+ def __enter__(self):
+ self.logger.addHandler(self.sh)
+ return self
+
+ def __exit__(self, *exc):
+ self.logger.removeHandler(self.sh)
+ self.out = self.io.getvalue()
+
+ def __repr__(self):
+ return f"captured: {self.out}\n"
+
+
+@contextlib.contextmanager
+def LoggingLevel(level):
+ """
+ This is a context manager to temporarily change transformers modules logging level to the desired value and have it
+ restored to the original setting at the end of the scope.
+
+ Example:
+
+ ```python
+ with LoggingLevel(logging.INFO):
+ AutoModel.from_pretrained("openai-community/gpt2") # calls logger.info() several times
+ ```
+ """
+ orig_level = transformers_logging.get_verbosity()
+ try:
+ transformers_logging.set_verbosity(level)
+ yield
+ finally:
+ transformers_logging.set_verbosity(orig_level)
+
+
+class TemporaryHubRepo:
+ """Create a temporary Hub repository and return its `RepoUrl` object. This is similar to
+ `tempfile.TemporaryDirectory` and can be used as a context manager. For example:
+
+ with TemporaryHubRepo(token=self._token) as temp_repo:
+ ...
+
+ Upon exiting the context, the repository and everything contained in it are removed.
+
+ Example:
+
+ ```python
+ with TemporaryHubRepo(token=self._token) as temp_repo:
+ model.push_to_hub(tmp_repo.repo_id, token=self._token)
+ ```
+ """
+
+ def __init__(self, namespace: Optional[str] = None, token: Optional[str] = None) -> None:
+ self.token = token
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ repo_id = Path(tmp_dir).name
+ if namespace is not None:
+ repo_id = f"{namespace}/{repo_id}"
+ self.repo_url = huggingface_hub.create_repo(repo_id, token=self.token)
+
+ def __enter__(self):
+ return self.repo_url
+
+ def __exit__(self, exc, value, tb):
+ delete_repo(repo_id=self.repo_url.repo_id, token=self.token, missing_ok=True)
+
+
+@contextlib.contextmanager
+# adapted from https://stackoverflow.com/a/64789046/9201239
+def ExtendSysPath(path: Union[str, os.PathLike]) -> Iterator[None]:
+ """
+ Temporary add given path to `sys.path`.
+
+ Usage :
+
+ ```python
+ with ExtendSysPath("/path/to/dir"):
+ mymodule = importlib.import_module("mymodule")
+ ```
+ """
+
+ path = os.fspath(path)
+ try:
+ sys.path.insert(0, path)
+ yield
+ finally:
+ sys.path.remove(path)
+
+
+class TestCasePlus(unittest.TestCase):
+ """
+ This class extends *unittest.TestCase* with additional features.
+
+ Feature 1: A set of fully resolved important file and dir path accessors.
+
+ In tests often we need to know where things are relative to the current test file, and it's not trivial since the
+ test could be invoked from more than one directory or could reside in sub-directories with different depths. This
+ class solves this problem by sorting out all the basic paths and provides easy accessors to them:
+
+ - `pathlib` objects (all fully resolved):
+
+ - `test_file_path` - the current test file path (=`__file__`)
+ - `test_file_dir` - the directory containing the current test file
+ - `tests_dir` - the directory of the `tests` test suite
+ - `examples_dir` - the directory of the `examples` test suite
+ - `repo_root_dir` - the directory of the repository
+ - `src_dir` - the directory of `src` (i.e. where the `transformers` sub-dir resides)
+
+ - stringified paths---same as above but these return paths as strings, rather than `pathlib` objects:
+
+ - `test_file_path_str`
+ - `test_file_dir_str`
+ - `tests_dir_str`
+ - `examples_dir_str`
+ - `repo_root_dir_str`
+ - `src_dir_str`
+
+ Feature 2: Flexible auto-removable temporary dirs which are guaranteed to get removed at the end of test.
+
+ 1. Create a unique temporary dir:
+
+ ```python
+ def test_whatever(self):
+ tmp_dir = self.get_auto_remove_tmp_dir()
+ ```
+
+ `tmp_dir` will contain the path to the created temporary dir. It will be automatically removed at the end of the
+ test.
+
+
+ 2. Create a temporary dir of my choice, ensure it's empty before the test starts and don't
+ empty it after the test.
+
+ ```python
+ def test_whatever(self):
+ tmp_dir = self.get_auto_remove_tmp_dir("./xxx")
+ ```
+
+ This is useful for debug when you want to monitor a specific directory and want to make sure the previous tests
+ didn't leave any data in there.
+
+ 3. You can override the first two options by directly overriding the `before` and `after` args, leading to the
+ following behavior:
+
+ `before=True`: the temporary dir will always be cleared at the beginning of the test.
+
+ `before=False`: if the temporary dir already existed, any existing files will remain there.
+
+ `after=True`: the temporary dir will always be deleted at the end of the test.
+
+ `after=False`: the temporary dir will always be left intact at the end of the test.
+
+ Note 1: In order to run the equivalent of `rm -r` safely, only subdirs of the project repository checkout are
+ allowed if an explicit `tmp_dir` is used, so that by mistake no `/tmp` or similar important part of the filesystem
+ will get nuked. i.e. please always pass paths that start with `./`
+
+ Note 2: Each test can register multiple temporary dirs and they all will get auto-removed, unless requested
+ otherwise.
+
+ Feature 3: Get a copy of the `os.environ` object that sets up `PYTHONPATH` specific to the current test suite. This
+ is useful for invoking external programs from the test suite - e.g. distributed training.
+
+
+ ```python
+ def test_whatever(self):
+ env = self.get_env()
+ ```"""
+
+ def setUp(self):
+ # get_auto_remove_tmp_dir feature:
+ self.teardown_tmp_dirs = []
+
+ # figure out the resolved paths for repo_root, tests, examples, etc.
+ self._test_file_path = inspect.getfile(self.__class__)
+ path = Path(self._test_file_path).resolve()
+ self._test_file_dir = path.parents[0]
+ for up in [1, 2, 3]:
+ tmp_dir = path.parents[up]
+ if (tmp_dir / "src").is_dir() and (tmp_dir / "tests").is_dir():
+ break
+ if tmp_dir:
+ self._repo_root_dir = tmp_dir
+ else:
+ raise ValueError(f"can't figure out the root of the repo from {self._test_file_path}")
+ self._tests_dir = self._repo_root_dir / "tests"
+ self._examples_dir = self._repo_root_dir / "examples"
+ self._src_dir = self._repo_root_dir / "src"
+
+ @property
+ def test_file_path(self):
+ return self._test_file_path
+
+ @property
+ def test_file_path_str(self):
+ return str(self._test_file_path)
+
+ @property
+ def test_file_dir(self):
+ return self._test_file_dir
+
+ @property
+ def test_file_dir_str(self):
+ return str(self._test_file_dir)
+
+ @property
+ def tests_dir(self):
+ return self._tests_dir
+
+ @property
+ def tests_dir_str(self):
+ return str(self._tests_dir)
+
+ @property
+ def examples_dir(self):
+ return self._examples_dir
+
+ @property
+ def examples_dir_str(self):
+ return str(self._examples_dir)
+
+ @property
+ def repo_root_dir(self):
+ return self._repo_root_dir
+
+ @property
+ def repo_root_dir_str(self):
+ return str(self._repo_root_dir)
+
+ @property
+ def src_dir(self):
+ return self._src_dir
+
+ @property
+ def src_dir_str(self):
+ return str(self._src_dir)
+
+ def get_env(self):
+ """
+ Return a copy of the `os.environ` object that sets up `PYTHONPATH` correctly, depending on the test suite it's
+ invoked from. This is useful for invoking external programs from the test suite - e.g. distributed training.
+
+ It always inserts `./src` first, then `./tests` or `./examples` depending on the test suite type and finally
+ the preset `PYTHONPATH` if any (all full resolved paths).
+
+ """
+ env = os.environ.copy()
+ paths = [self.repo_root_dir_str, self.src_dir_str]
+ if "/examples" in self.test_file_dir_str:
+ paths.append(self.examples_dir_str)
+ else:
+ paths.append(self.tests_dir_str)
+ paths.append(env.get("PYTHONPATH", ""))
+
+ env["PYTHONPATH"] = ":".join(paths)
+ return env
+
+ def get_auto_remove_tmp_dir(self, tmp_dir=None, before=None, after=None):
+ """
+ Args:
+ tmp_dir (`string`, *optional*):
+ if `None`:
+
+ - a unique temporary path will be created
+ - sets `before=True` if `before` is `None`
+ - sets `after=True` if `after` is `None`
+ else:
+
+ - `tmp_dir` will be created
+ - sets `before=True` if `before` is `None`
+ - sets `after=False` if `after` is `None`
+ before (`bool`, *optional*):
+ If `True` and the `tmp_dir` already exists, make sure to empty it right away if `False` and the
+ `tmp_dir` already exists, any existing files will remain there.
+ after (`bool`, *optional*):
+ If `True`, delete the `tmp_dir` at the end of the test if `False`, leave the `tmp_dir` and its contents
+ intact at the end of the test.
+
+ Returns:
+ tmp_dir(`string`): either the same value as passed via *tmp_dir* or the path to the auto-selected tmp dir
+ """
+ if tmp_dir is not None:
+ # defining the most likely desired behavior for when a custom path is provided.
+ # this most likely indicates the debug mode where we want an easily locatable dir that:
+ # 1. gets cleared out before the test (if it already exists)
+ # 2. is left intact after the test
+ if before is None:
+ before = True
+ if after is None:
+ after = False
+
+ # using provided path
+ path = Path(tmp_dir).resolve()
+
+ # to avoid nuking parts of the filesystem, only relative paths are allowed
+ if not tmp_dir.startswith("./"):
+ raise ValueError(
+ f"`tmp_dir` can only be a relative path, i.e. `./some/path`, but received `{tmp_dir}`"
+ )
+
+ # ensure the dir is empty to start with
+ if before is True and path.exists():
+ shutil.rmtree(tmp_dir, ignore_errors=True)
+
+ path.mkdir(parents=True, exist_ok=True)
+
+ else:
+ # defining the most likely desired behavior for when a unique tmp path is auto generated
+ # (not a debug mode), here we require a unique tmp dir that:
+ # 1. is empty before the test (it will be empty in this situation anyway)
+ # 2. gets fully removed after the test
+ if before is None:
+ before = True
+ if after is None:
+ after = True
+
+ # using unique tmp dir (always empty, regardless of `before`)
+ tmp_dir = tempfile.mkdtemp()
+
+ if after is True:
+ # register for deletion
+ self.teardown_tmp_dirs.append(tmp_dir)
+
+ return tmp_dir
+
+ def python_one_liner_max_rss(self, one_liner_str):
+ """
+ Runs the passed python one liner (just the code) and returns how much max cpu memory was used to run the
+ program.
+
+ Args:
+ one_liner_str (`string`):
+ a python one liner code that gets passed to `python -c`
+
+ Returns:
+ max cpu memory bytes used to run the program. This value is likely to vary slightly from run to run.
+
+ Requirements:
+ this helper needs `/usr/bin/time` to be installed (`apt install time`)
+
+ Example:
+
+ ```
+ one_liner_str = 'from transformers import AutoModel; AutoModel.from_pretrained("google-t5/t5-large")'
+ max_rss = self.python_one_liner_max_rss(one_liner_str)
+ ```
+ """
+
+ if not cmd_exists("/usr/bin/time"):
+ raise ValueError("/usr/bin/time is required, install with `apt install time`")
+
+ cmd = shlex.split(f"/usr/bin/time -f %M python -c '{one_liner_str}'")
+ with CaptureStd() as cs:
+ execute_subprocess_async(cmd, env=self.get_env())
+ # returned data is in KB so convert to bytes
+ max_rss = int(cs.err.split("\n")[-2].replace("stderr: ", "")) * 1024
+ return max_rss
+
+ def tearDown(self):
+ # get_auto_remove_tmp_dir feature: remove registered temp dirs
+ for path in self.teardown_tmp_dirs:
+ shutil.rmtree(path, ignore_errors=True)
+ self.teardown_tmp_dirs = []
+ if is_accelerate_available():
+ AcceleratorState._reset_state()
+ PartialState._reset_state()
+
+ # delete all the env variables having `ACCELERATE` in them
+ for k in list(os.environ.keys()):
+ if "ACCELERATE" in k:
+ del os.environ[k]
+
+
+def mockenv(**kwargs):
+ """
+ this is a convenience wrapper, that allows this ::
+
+ @mockenv(RUN_SLOW=True, USE_TF=False) def test_something():
+ run_slow = os.getenv("RUN_SLOW", False) use_tf = os.getenv("USE_TF", False)
+
+ """
+ return mock.patch.dict(os.environ, kwargs)
+
+
+# from https://stackoverflow.com/a/34333710/9201239
+@contextlib.contextmanager
+def mockenv_context(*remove, **update):
+ """
+ Temporarily updates the `os.environ` dictionary in-place. Similar to mockenv
+
+ The `os.environ` dictionary is updated in-place so that the modification is sure to work in all situations.
+
+ Args:
+ remove: Environment variables to remove.
+ update: Dictionary of environment variables and values to add/update.
+ """
+ env = os.environ
+ update = update or {}
+ remove = remove or []
+
+ # List of environment variables being updated or removed.
+ stomped = (set(update.keys()) | set(remove)) & set(env.keys())
+ # Environment variables and values to restore on exit.
+ update_after = {k: env[k] for k in stomped}
+ # Environment variables and values to remove on exit.
+ remove_after = frozenset(k for k in update if k not in env)
+
+ try:
+ env.update(update)
+ [env.pop(k, None) for k in remove]
+ yield
+ finally:
+ env.update(update_after)
+ [env.pop(k) for k in remove_after]
+
+
+# --- pytest conf functions --- #
+
+# to avoid multiple invocation from tests/conftest.py and examples/conftest.py - make sure it's called only once
+pytest_opt_registered = {}
+
+
+def pytest_addoption_shared(parser):
+ """
+ This function is to be called from `conftest.py` via `pytest_addoption` wrapper that has to be defined there.
+
+ It allows loading both `conftest.py` files at once without causing a failure due to adding the same `pytest`
+ option.
+
+ """
+ option = "--make-reports"
+ if option not in pytest_opt_registered:
+ parser.addoption(
+ option,
+ action="store",
+ default=False,
+ help="generate report files. The value of this option is used as a prefix to report names",
+ )
+ pytest_opt_registered[option] = 1
+
+
+def pytest_terminal_summary_main(tr, id):
+ """
+ Generate multiple reports at the end of test suite run - each report goes into a dedicated file in the current
+ directory. The report files are prefixed with the test suite name.
+
+ This function emulates --duration and -rA pytest arguments.
+
+ This function is to be called from `conftest.py` via `pytest_terminal_summary` wrapper that has to be defined
+ there.
+
+ Args:
+ - tr: `terminalreporter` passed from `conftest.py`
+ - id: unique id like `tests` or `examples` that will be incorporated into the final reports filenames - this is
+ needed as some jobs have multiple runs of pytest, so we can't have them overwrite each other.
+
+ NB: this functions taps into a private _pytest API and while unlikely, it could break should pytest do internal
+ changes - also it calls default internal methods of terminalreporter which can be hijacked by various `pytest-`
+ plugins and interfere.
+
+ """
+ from _pytest.config import create_terminal_writer
+
+ if not len(id):
+ id = "tests"
+
+ config = tr.config
+ orig_writer = config.get_terminal_writer()
+ orig_tbstyle = config.option.tbstyle
+ orig_reportchars = tr.reportchars
+
+ dir = f"reports/{id}"
+ Path(dir).mkdir(parents=True, exist_ok=True)
+ report_files = {
+ k: f"{dir}/{k}.txt"
+ for k in [
+ "durations",
+ "errors",
+ "failures_long",
+ "failures_short",
+ "failures_line",
+ "passes",
+ "stats",
+ "summary_short",
+ "warnings",
+ ]
+ }
+
+ # custom durations report
+ # note: there is no need to call pytest --durations=XX to get this separate report
+ # adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/runner.py#L66
+ dlist = []
+ for replist in tr.stats.values():
+ for rep in replist:
+ if hasattr(rep, "duration"):
+ dlist.append(rep)
+ if dlist:
+ dlist.sort(key=lambda x: x.duration, reverse=True)
+ with open(report_files["durations"], "w") as f:
+ durations_min = 0.05 # sec
+ f.write("slowest durations\n")
+ for i, rep in enumerate(dlist):
+ if rep.duration < durations_min:
+ f.write(f"{len(dlist) - i} durations < {durations_min} secs were omitted")
+ break
+ f.write(f"{rep.duration:02.2f}s {rep.when:<8} {rep.nodeid}\n")
+
+ def summary_failures_short(tr):
+ # expecting that the reports were --tb=long (default) so we chop them off here to the last frame
+ reports = tr.getreports("failed")
+ if not reports:
+ return
+ tr.write_sep("=", "FAILURES SHORT STACK")
+ for rep in reports:
+ msg = tr._getfailureheadline(rep)
+ tr.write_sep("_", msg, red=True, bold=True)
+ # chop off the optional leading extra frames, leaving only the last one
+ longrepr = re.sub(r".*_ _ _ (_ ){10,}_ _ ", "", rep.longreprtext, 0, re.MULTILINE | re.DOTALL)
+ tr._tw.line(longrepr)
+ # note: not printing out any rep.sections to keep the report short
+
+ # use ready-made report funcs, we are just hijacking the filehandle to log to a dedicated file each
+ # adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/terminal.py#L814
+ # note: some pytest plugins may interfere by hijacking the default `terminalreporter` (e.g.
+ # pytest-instafail does that)
+
+ # report failures with line/short/long styles
+ config.option.tbstyle = "auto" # full tb
+ with open(report_files["failures_long"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ tr.summary_failures()
+
+ # config.option.tbstyle = "short" # short tb
+ with open(report_files["failures_short"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ summary_failures_short(tr)
+
+ config.option.tbstyle = "line" # one line per error
+ with open(report_files["failures_line"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ tr.summary_failures()
+
+ with open(report_files["errors"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ tr.summary_errors()
+
+ with open(report_files["warnings"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ tr.summary_warnings() # normal warnings
+ tr.summary_warnings() # final warnings
+
+ tr.reportchars = "wPpsxXEf" # emulate -rA (used in summary_passes() and short_test_summary())
+
+ # Skip the `passes` report, as it starts to take more than 5 minutes, and sometimes it timeouts on CircleCI if it
+ # takes > 10 minutes (as this part doesn't generate any output on the terminal).
+ # (also, it seems there is no useful information in this report, and we rarely need to read it)
+ # with open(report_files["passes"], "w") as f:
+ # tr._tw = create_terminal_writer(config, f)
+ # tr.summary_passes()
+
+ with open(report_files["summary_short"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ tr.short_test_summary()
+
+ with open(report_files["stats"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ tr.summary_stats()
+
+ # restore:
+ tr._tw = orig_writer
+ tr.reportchars = orig_reportchars
+ config.option.tbstyle = orig_tbstyle
+
+
+# --- distributed testing functions --- #
+
+# adapted from https://stackoverflow.com/a/59041913/9201239
+import asyncio # noqa
+
+
+class _RunOutput:
+ def __init__(self, returncode, stdout, stderr):
+ self.returncode = returncode
+ self.stdout = stdout
+ self.stderr = stderr
+
+
+async def _read_stream(stream, callback):
+ while True:
+ line = await stream.readline()
+ if line:
+ callback(line)
+ else:
+ break
+
+
+async def _stream_subprocess(cmd, env=None, stdin=None, timeout=None, quiet=False, echo=False) -> _RunOutput:
+ if echo:
+ print("\nRunning: ", " ".join(cmd))
+
+ p = await asyncio.create_subprocess_exec(
+ cmd[0],
+ *cmd[1:],
+ stdin=stdin,
+ stdout=asyncio.subprocess.PIPE,
+ stderr=asyncio.subprocess.PIPE,
+ env=env,
+ )
+
+ # note: there is a warning for a possible deadlock when using `wait` with huge amounts of data in the pipe
+ # https://docs.python.org/3/library/asyncio-subprocess.html#asyncio.asyncio.subprocess.Process.wait
+ #
+ # If it starts hanging, will need to switch to the following code. The problem is that no data
+ # will be seen until it's done and if it hangs for example there will be no debug info.
+ # out, err = await p.communicate()
+ # return _RunOutput(p.returncode, out, err)
+
+ out = []
+ err = []
+
+ def tee(line, sink, pipe, label=""):
+ line = line.decode("utf-8").rstrip()
+ sink.append(line)
+ if not quiet:
+ print(label, line, file=pipe)
+
+ # XXX: the timeout doesn't seem to make any difference here
+ await asyncio.wait(
+ [
+ asyncio.create_task(_read_stream(p.stdout, lambda l: tee(l, out, sys.stdout, label="stdout:"))),
+ asyncio.create_task(_read_stream(p.stderr, lambda l: tee(l, err, sys.stderr, label="stderr:"))),
+ ],
+ timeout=timeout,
+ )
+ return _RunOutput(await p.wait(), out, err)
+
+
+def execute_subprocess_async(cmd, env=None, stdin=None, timeout=180, quiet=False, echo=True) -> _RunOutput:
+ loop = asyncio.get_event_loop()
+ result = loop.run_until_complete(
+ _stream_subprocess(cmd, env=env, stdin=stdin, timeout=timeout, quiet=quiet, echo=echo)
+ )
+
+ cmd_str = " ".join(cmd)
+ if result.returncode > 0:
+ stderr = "\n".join(result.stderr)
+ raise RuntimeError(
+ f"'{cmd_str}' failed with returncode {result.returncode}\n\n"
+ f"The combined stderr from workers follows:\n{stderr}"
+ )
+
+ # check that the subprocess actually did run and produced some output, should the test rely on
+ # the remote side to do the testing
+ if not result.stdout and not result.stderr:
+ raise RuntimeError(f"'{cmd_str}' produced no output.")
+
+ return result
+
+
+def pytest_xdist_worker_id():
+ """
+ Returns an int value of worker's numerical id under `pytest-xdist`'s concurrent workers `pytest -n N` regime, or 0
+ if `-n 1` or `pytest-xdist` isn't being used.
+ """
+ worker = os.environ.get("PYTEST_XDIST_WORKER", "gw0")
+ worker = re.sub(r"^gw", "", worker, 0, re.MULTILINE)
+ return int(worker)
+
+
+def get_torch_dist_unique_port():
+ """
+ Returns a port number that can be fed to `torch.distributed.launch`'s `--master_port` argument.
+
+ Under `pytest-xdist` it adds a delta number based on a worker id so that concurrent tests don't try to use the same
+ port at once.
+ """
+ port = 29500
+ uniq_delta = pytest_xdist_worker_id()
+ return port + uniq_delta
+
+
+def nested_simplify(obj, decimals=3):
+ """
+ Simplifies an object by rounding float numbers, and downcasting tensors/numpy arrays to get simple equality test
+ within tests.
+ """
+ import numpy as np
+
+ if isinstance(obj, list):
+ return [nested_simplify(item, decimals) for item in obj]
+ if isinstance(obj, tuple):
+ return tuple(nested_simplify(item, decimals) for item in obj)
+ elif isinstance(obj, np.ndarray):
+ return nested_simplify(obj.tolist())
+ elif isinstance(obj, Mapping):
+ return {nested_simplify(k, decimals): nested_simplify(v, decimals) for k, v in obj.items()}
+ elif isinstance(obj, (str, int, np.int64)) or obj is None:
+ return obj
+ elif is_torch_available() and isinstance(obj, torch.Tensor):
+ return nested_simplify(obj.tolist(), decimals)
+ elif is_tf_available() and tf.is_tensor(obj):
+ return nested_simplify(obj.numpy().tolist())
+ elif isinstance(obj, float):
+ return round(obj, decimals)
+ elif isinstance(obj, (np.int32, np.float32, np.float16)):
+ return nested_simplify(obj.item(), decimals)
+ else:
+ raise Exception(f"Not supported: {type(obj)}")
+
+
+def check_json_file_has_correct_format(file_path):
+ with open(file_path) as f:
+ lines = f.readlines()
+ if len(lines) == 1:
+ # length can only be 1 if dict is empty
+ assert lines[0] == "{}"
+ else:
+ # otherwise make sure json has correct format (at least 3 lines)
+ assert len(lines) >= 3
+ # each key one line, ident should be 2, min length is 3
+ assert lines[0].strip() == "{"
+ for line in lines[1:-1]:
+ left_indent = len(lines[1]) - len(lines[1].lstrip())
+ assert left_indent == 2
+ assert lines[-1].strip() == "}"
+
+
+def to_2tuple(x):
+ if isinstance(x, collections.abc.Iterable):
+ return x
+ return (x, x)
+
+
+# These utils relate to ensuring the right error message is received when running scripts
+class SubprocessCallException(Exception):
+ pass
+
+
+def run_command(command: list[str], return_stdout=False):
+ """
+ Runs `command` with `subprocess.check_output` and will potentially return the `stdout`. Will also properly capture
+ if an error occurred while running `command`
+ """
+ try:
+ output = subprocess.check_output(command, stderr=subprocess.STDOUT)
+ if return_stdout:
+ if hasattr(output, "decode"):
+ output = output.decode("utf-8")
+ return output
+ except subprocess.CalledProcessError as e:
+ raise SubprocessCallException(
+ f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}"
+ ) from e
+
+
+class RequestCounter:
+ """
+ Helper class that will count all requests made online.
+
+ Might not be robust if urllib3 changes its logging format but should be good enough for us.
+
+ Usage:
+ ```py
+ with RequestCounter() as counter:
+ _ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
+ assert counter["GET"] == 0
+ assert counter["HEAD"] == 1
+ assert counter.total_calls == 1
+ ```
+ """
+
+ def __enter__(self):
+ self._counter = defaultdict(int)
+ self._thread_id = threading.get_ident()
+ self._extra_info = []
+
+ def patched_with_thread_info(func):
+ def wrap(*args, **kwargs):
+ self._extra_info.append(threading.get_ident())
+ return func(*args, **kwargs)
+
+ return wrap
+
+ self.patcher = patch.object(
+ urllib3.connectionpool.log, "debug", side_effect=patched_with_thread_info(urllib3.connectionpool.log.debug)
+ )
+ self.mock = self.patcher.start()
+ return self
+
+ def __exit__(self, *args, **kwargs) -> None:
+ assert len(self.mock.call_args_list) == len(self._extra_info)
+ for thread_id, call in zip(self._extra_info, self.mock.call_args_list):
+ if thread_id != self._thread_id:
+ continue
+ # code 307: the URL being requested by the user has moved to a temporary location
+ if call.args[-2] == 307:
+ continue
+ log = call.args[0] % call.args[1:]
+ for method in ("HEAD", "GET", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH"):
+ if method in log:
+ self._counter[method] += 1
+ break
+ self.patcher.stop()
+
+ def __getitem__(self, key: str) -> int:
+ return self._counter[key]
+
+ @property
+ def total_calls(self) -> int:
+ return sum(self._counter.values())
+
+
+def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, description: Optional[str] = None):
+ """
+ To decorate flaky tests. They will be retried on failures.
+
+ Please note that our push tests use `pytest-rerunfailures`, which prompts the CI to rerun certain types of
+ failed tests. More specifically, if the test exception contains any substring in `FLAKY_TEST_FAILURE_PATTERNS`
+ (in `.circleci/create_circleci_config.py`), it will be rerun. If you find a recurrent pattern of failures,
+ expand `FLAKY_TEST_FAILURE_PATTERNS` in our CI configuration instead of using `is_flaky`.
+
+ Args:
+ max_attempts (`int`, *optional*, defaults to 5):
+ The maximum number of attempts to retry the flaky test.
+ wait_before_retry (`float`, *optional*):
+ If provided, will wait that number of seconds before retrying the test.
+ description (`str`, *optional*):
+ A string to describe the situation (what / where / why is flaky, link to GH issue/PR comments, errors,
+ etc.)
+ """
+
+ def decorator(test_func_ref):
+ @functools.wraps(test_func_ref)
+ def wrapper(*args, **kwargs):
+ retry_count = 1
+
+ while retry_count < max_attempts:
+ try:
+ return test_func_ref(*args, **kwargs)
+
+ except Exception as err:
+ logger.error(f"Test failed with {err} at try {retry_count}/{max_attempts}.")
+ if wait_before_retry is not None:
+ time.sleep(wait_before_retry)
+ retry_count += 1
+
+ return test_func_ref(*args, **kwargs)
+
+ return unittest.skipUnless(_run_flaky_tests, "test is flaky")(wrapper)
+
+ return decorator
+
+
+def hub_retry(max_attempts: int = 5, wait_before_retry: Optional[float] = 2):
+ """
+ To decorate tests that download from the Hub. They can fail due to a
+ variety of network issues such as timeouts, connection resets, etc.
+
+ Args:
+ max_attempts (`int`, *optional*, defaults to 5):
+ The maximum number of attempts to retry the flaky test.
+ wait_before_retry (`float`, *optional*, defaults to 2):
+ If provided, will wait that number of seconds before retrying the test.
+ """
+
+ def decorator(test_func_ref):
+ @functools.wraps(test_func_ref)
+ def wrapper(*args, **kwargs):
+ retry_count = 1
+
+ while retry_count < max_attempts:
+ try:
+ return test_func_ref(*args, **kwargs)
+ # We catch all exceptions related to network issues from requests
+ except (
+ requests.exceptions.ConnectionError,
+ requests.exceptions.Timeout,
+ requests.exceptions.ReadTimeout,
+ requests.exceptions.HTTPError,
+ requests.exceptions.RequestException,
+ ) as err:
+ logger.error(
+ f"Test failed with {err} at try {retry_count}/{max_attempts} as it couldn't connect to the specified Hub repository."
+ )
+ if wait_before_retry is not None:
+ time.sleep(wait_before_retry)
+ retry_count += 1
+
+ return test_func_ref(*args, **kwargs)
+
+ return wrapper
+
+ return decorator
+
+
+def run_first(test_case):
+ """
+ Decorator marking a test with order(1). When pytest-order plugin is installed, tests marked with this decorator
+ are guaranteed to run first.
+
+ This is especially useful in some test settings like on a Gaudi instance where a Gaudi device can only be used by a
+ single process at a time. So we make sure all tests that run in a subprocess are launched first, to avoid device
+ allocation conflicts.
+ """
+ import pytest
+
+ return pytest.mark.order(1)(test_case)
+
+
+def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None):
+ """
+ To run a test in a subprocess. In particular, this can avoid (GPU) memory issue.
+
+ Args:
+ test_case (`unittest.TestCase`):
+ The test that will run `target_func`.
+ target_func (`Callable`):
+ The function implementing the actual testing logic.
+ inputs (`dict`, *optional*, defaults to `None`):
+ The inputs that will be passed to `target_func` through an (input) queue.
+ timeout (`int`, *optional*, defaults to `None`):
+ The timeout (in seconds) that will be passed to the input and output queues. If not specified, the env.
+ variable `PYTEST_TIMEOUT` will be checked. If still `None`, its value will be set to `600`.
+ """
+ if timeout is None:
+ timeout = int(os.environ.get("PYTEST_TIMEOUT", "600"))
+
+ start_methohd = "spawn"
+ ctx = multiprocessing.get_context(start_methohd)
+
+ input_queue = ctx.Queue(1)
+ output_queue = ctx.JoinableQueue(1)
+
+ # We can't send `unittest.TestCase` to the child, otherwise we get issues regarding pickle.
+ input_queue.put(inputs, timeout=timeout)
+
+ process = ctx.Process(target=target_func, args=(input_queue, output_queue, timeout))
+ process.start()
+ # Kill the child process if we can't get outputs from it in time: otherwise, the hanging subprocess prevents
+ # the test to exit properly.
+ try:
+ results = output_queue.get(timeout=timeout)
+ output_queue.task_done()
+ except Exception as e:
+ process.terminate()
+ test_case.fail(e)
+ process.join(timeout=timeout)
+
+ if results["error"] is not None:
+ test_case.fail(f"{results['error']}")
+
+
+def run_test_using_subprocess(func):
+ """
+ To decorate a test to run in a subprocess using the `subprocess` module. This could avoid potential GPU memory
+ issues (GPU OOM or a test that causes many subsequential failing with `CUDA error: device-side assert triggered`).
+ """
+ import pytest
+
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ if os.getenv("_INSIDE_SUB_PROCESS", None) == "1":
+ func(*args, **kwargs)
+ else:
+ test = " ".join(os.environ.get("PYTEST_CURRENT_TEST").split(" ")[:-1])
+ try:
+ env = copy.deepcopy(os.environ)
+ env["_INSIDE_SUB_PROCESS"] = "1"
+ # This prevents the entries in `short test summary info` given by the subprocess being truncated. so the
+ # full information can be passed to the parent pytest process.
+ # See: https://docs.pytest.org/en/stable/explanation/ci.html
+ env["CI"] = "true"
+
+ # If not subclass of `unitTest.TestCase` and `pytestconfig` is used: try to grab and use the arguments
+ if "pytestconfig" in kwargs:
+ command = list(kwargs["pytestconfig"].invocation_params.args)
+ for idx, x in enumerate(command):
+ if x in kwargs["pytestconfig"].args:
+ test = test.split("::")[1:]
+ command[idx] = "::".join([f"{func.__globals__['__file__']}"] + test)
+ command = [f"{sys.executable}", "-m", "pytest"] + command
+ command = [x for x in command if x != "--no-summary"]
+ # Otherwise, simply run the test with no option at all
+ else:
+ command = [f"{sys.executable}", "-m", "pytest", f"{test}"]
+
+ subprocess.run(command, env=env, check=True, capture_output=True)
+ except subprocess.CalledProcessError as e:
+ exception_message = e.stdout.decode()
+ lines = exception_message.split("\n")
+ # Add a first line with more informative information instead of just `= test session starts =`.
+ # This makes the `short test summary info` section more useful.
+ if "= test session starts =" in lines[0]:
+ text = ""
+ for line in lines[1:]:
+ if line.startswith("FAILED "):
+ text = line[len("FAILED ") :]
+ text = "".join(text.split(" - ")[1:])
+ elif line.startswith("=") and line.endswith("=") and " failed in " in line:
+ break
+ elif len(text) > 0:
+ text += f"\n{line}"
+ text = "(subprocess) " + text
+ lines = [text] + lines
+ exception_message = "\n".join(lines)
+ raise pytest.fail(exception_message, pytrace=False)
+
+ return wrapper
+
+
+"""
+The following contains utils to run the documentation tests without having to overwrite any files.
+
+The `preprocess_string` function adds `# doctest: +IGNORE_RESULT` markers on the fly anywhere a `load_dataset` call is
+made as a print would otherwise fail the corresponding line.
+
+To skip cuda tests, make sure to call `SKIP_CUDA_DOCTEST=1 pytest --doctest-modules
+"""
+
+
+def preprocess_string(string, skip_cuda_tests):
+ """Prepare a docstring or a `.md` file to be run by doctest.
+
+ The argument `string` would be the whole file content if it is a `.md` file. For a python file, it would be one of
+ its docstring. In each case, it may contain multiple python code examples. If `skip_cuda_tests` is `True` and a
+ cuda stuff is detective (with a heuristic), this method will return an empty string so no doctest will be run for
+ `string`.
+ """
+ codeblock_pattern = r"(```(?:python|py)\s*\n\s*>>> )(.*?```)"
+ codeblocks = re.split(codeblock_pattern, string, flags=re.DOTALL)
+ is_cuda_found = False
+ for i, codeblock in enumerate(codeblocks):
+ if "load_dataset(" in codeblock and "# doctest: +IGNORE_RESULT" not in codeblock:
+ codeblocks[i] = re.sub(r"(>>> .*load_dataset\(.*)", r"\1 # doctest: +IGNORE_RESULT", codeblock)
+ if (
+ (">>>" in codeblock or "..." in codeblock)
+ and re.search(r"cuda|to\(0\)|device=0", codeblock)
+ and skip_cuda_tests
+ ):
+ is_cuda_found = True
+ break
+
+ modified_string = ""
+ if not is_cuda_found:
+ modified_string = "".join(codeblocks)
+
+ return modified_string
+
+
+class HfDocTestParser(doctest.DocTestParser):
+ """
+ Overwrites the DocTestParser from doctest to properly parse the codeblocks that are formatted with black. This
+ means that there are no extra lines at the end of our snippets. The `# doctest: +IGNORE_RESULT` marker is also
+ added anywhere a `load_dataset` call is made as a print would otherwise fail the corresponding line.
+
+ Tests involving cuda are skipped base on a naive pattern that should be updated if it is not enough.
+ """
+
+ # This regular expression is used to find doctest examples in a
+ # string. It defines three groups: `source` is the source code
+ # (including leading indentation and prompts); `indent` is the
+ # indentation of the first (PS1) line of the source code; and
+ # `want` is the expected output (including leading indentation).
+ # fmt: off
+ _EXAMPLE_RE = re.compile(r'''
+ # Source consists of a PS1 line followed by zero or more PS2 lines.
+ (?P
+ (?:^(?P [ ]*) >>> .*) # PS1 line
+ (?:\n [ ]* \.\.\. .*)*) # PS2 lines
+ \n?
+ # Want consists of any non-blank lines that do not start with PS1.
+ (?P (?:(?![ ]*$) # Not a blank line
+ (?![ ]*>>>) # Not a line starting with PS1
+ # !!!!!!!!!!! HF Specific !!!!!!!!!!!
+ (?:(?!```).)* # Match any character except '`' until a '```' is found (this is specific to HF because black removes the last line)
+ # !!!!!!!!!!! HF Specific !!!!!!!!!!!
+ (?:\n|$) # Match a new line or end of string
+ )*)
+ ''', re.MULTILINE | re.VERBOSE
+ )
+ # fmt: on
+
+ # !!!!!!!!!!! HF Specific !!!!!!!!!!!
+ skip_cuda_tests: bool = bool(os.environ.get("SKIP_CUDA_DOCTEST", "0"))
+ # !!!!!!!!!!! HF Specific !!!!!!!!!!!
+
+ def parse(self, string, name=""):
+ """
+ Overwrites the `parse` method to incorporate a skip for CUDA tests, and remove logs and dataset prints before
+ calling `super().parse`
+ """
+ string = preprocess_string(string, self.skip_cuda_tests)
+ return super().parse(string, name)
+
+
+class HfDoctestModule(Module):
+ """
+ Overwrites the `DoctestModule` of the pytest package to make sure the HFDocTestParser is used when discovering
+ tests.
+ """
+
+ def collect(self) -> Iterable[DoctestItem]:
+ class MockAwareDocTestFinder(doctest.DocTestFinder):
+ """A hackish doctest finder that overrides stdlib internals to fix a stdlib bug.
+
+ https://github.com/pytest-dev/pytest/issues/3456 https://bugs.python.org/issue25532
+ """
+
+ def _find_lineno(self, obj, source_lines):
+ """Doctest code does not take into account `@property`, this
+ is a hackish way to fix it. https://bugs.python.org/issue17446
+
+ Wrapped Doctests will need to be unwrapped so the correct line number is returned. This will be
+ reported upstream. #8796
+ """
+ if isinstance(obj, property):
+ obj = getattr(obj, "fget", obj)
+
+ if hasattr(obj, "__wrapped__"):
+ # Get the main obj in case of it being wrapped
+ obj = inspect.unwrap(obj)
+
+ # Type ignored because this is a private function.
+ return super()._find_lineno( # type:ignore[misc]
+ obj,
+ source_lines,
+ )
+
+ def _find(self, tests, obj, name, module, source_lines, globs, seen) -> None:
+ if _is_mocked(obj):
+ return
+ with _patch_unwrap_mock_aware():
+ # Type ignored because this is a private function.
+ super()._find( # type:ignore[misc]
+ tests, obj, name, module, source_lines, globs, seen
+ )
+
+ if self.path.name == "conftest.py":
+ module = self.config.pluginmanager._importconftest(
+ self.path,
+ self.config.getoption("importmode"),
+ rootpath=self.config.rootpath,
+ )
+ else:
+ try:
+ module = import_path(
+ self.path,
+ root=self.config.rootpath,
+ mode=self.config.getoption("importmode"),
+ )
+ except ImportError:
+ if self.config.getvalue("doctest_ignore_import_errors"):
+ skip("unable to import module %r" % self.path)
+ else:
+ raise
+
+ # !!!!!!!!!!! HF Specific !!!!!!!!!!!
+ finder = MockAwareDocTestFinder(parser=HfDocTestParser())
+ # !!!!!!!!!!! HF Specific !!!!!!!!!!!
+ optionflags = get_optionflags(self)
+ runner = _get_runner(
+ verbose=False,
+ optionflags=optionflags,
+ checker=_get_checker(),
+ continue_on_failure=_get_continue_on_failure(self.config),
+ )
+ for test in finder.find(module, module.__name__):
+ if test.examples: # skip empty doctests and cuda
+ yield DoctestItem.from_parent(self, name=test.name, runner=runner, dtest=test)
+
+
+def _device_agnostic_dispatch(device: str, dispatch_table: dict[str, Callable], *args, **kwargs):
+ if device not in dispatch_table:
+ if not callable(dispatch_table["default"]):
+ return dispatch_table["default"]
+
+ return dispatch_table["default"](*args, **kwargs)
+
+ fn = dispatch_table[device]
+
+ # Some device agnostic functions return values or None, will return then directly.
+ if not callable(fn):
+ return fn
+
+ return fn(*args, **kwargs)
+
+
+if is_torch_available():
+ # Mappings from device names to callable functions to support device agnostic
+ # testing.
+ BACKEND_MANUAL_SEED = {
+ "cuda": torch.cuda.manual_seed,
+ "cpu": torch.manual_seed,
+ "default": torch.manual_seed,
+ }
+ BACKEND_EMPTY_CACHE = {
+ "cuda": torch.cuda.empty_cache,
+ "cpu": None,
+ "default": None,
+ }
+ BACKEND_DEVICE_COUNT = {
+ "cuda": torch.cuda.device_count,
+ "cpu": lambda: 0,
+ "default": lambda: 1,
+ }
+ BACKEND_RESET_MAX_MEMORY_ALLOCATED = {
+ "cuda": torch.cuda.reset_max_memory_allocated,
+ "cpu": None,
+ "default": None,
+ }
+ BACKEND_MAX_MEMORY_ALLOCATED = {
+ "cuda": torch.cuda.max_memory_allocated,
+ "cpu": 0,
+ "default": 0,
+ }
+ BACKEND_RESET_PEAK_MEMORY_STATS = {
+ "cuda": torch.cuda.reset_peak_memory_stats,
+ "cpu": None,
+ "default": None,
+ }
+ BACKEND_MEMORY_ALLOCATED = {
+ "cuda": torch.cuda.memory_allocated,
+ "cpu": 0,
+ "default": 0,
+ }
+ BACKEND_SYNCHRONIZE = {
+ "cuda": torch.cuda.synchronize,
+ "cpu": None,
+ "default": None,
+ }
+ BACKEND_TORCH_ACCELERATOR_MODULE = {
+ "cuda": torch.cuda,
+ "cpu": None,
+ "default": None,
+ }
+else:
+ BACKEND_MANUAL_SEED = {"default": None}
+ BACKEND_EMPTY_CACHE = {"default": None}
+ BACKEND_DEVICE_COUNT = {"default": lambda: 0}
+ BACKEND_RESET_MAX_MEMORY_ALLOCATED = {"default": None}
+ BACKEND_RESET_PEAK_MEMORY_STATS = {"default": None}
+ BACKEND_MAX_MEMORY_ALLOCATED = {"default": 0}
+ BACKEND_MEMORY_ALLOCATED = {"default": 0}
+ BACKEND_SYNCHRONIZE = {"default": None}
+ BACKEND_TORCH_ACCELERATOR_MODULE = {"default": None}
+
+
+if is_torch_hpu_available():
+ BACKEND_MANUAL_SEED["hpu"] = torch.hpu.manual_seed
+ BACKEND_DEVICE_COUNT["hpu"] = torch.hpu.device_count
+ BACKEND_TORCH_ACCELERATOR_MODULE["hpu"] = torch.hpu
+
+if is_torch_mlu_available():
+ BACKEND_EMPTY_CACHE["mlu"] = torch.mlu.empty_cache
+ BACKEND_MANUAL_SEED["mlu"] = torch.mlu.manual_seed
+ BACKEND_DEVICE_COUNT["mlu"] = torch.mlu.device_count
+ BACKEND_TORCH_ACCELERATOR_MODULE["mlu"] = torch.mlu
+
+if is_torch_npu_available():
+ BACKEND_EMPTY_CACHE["npu"] = torch.npu.empty_cache
+ BACKEND_MANUAL_SEED["npu"] = torch.npu.manual_seed
+ BACKEND_DEVICE_COUNT["npu"] = torch.npu.device_count
+ BACKEND_TORCH_ACCELERATOR_MODULE["npu"] = torch.npu
+
+if is_torch_xpu_available():
+ BACKEND_EMPTY_CACHE["xpu"] = torch.xpu.empty_cache
+ BACKEND_MANUAL_SEED["xpu"] = torch.xpu.manual_seed
+ BACKEND_DEVICE_COUNT["xpu"] = torch.xpu.device_count
+ BACKEND_RESET_MAX_MEMORY_ALLOCATED["xpu"] = torch.xpu.reset_peak_memory_stats
+ BACKEND_RESET_PEAK_MEMORY_STATS["xpu"] = torch.xpu.reset_peak_memory_stats
+ BACKEND_MAX_MEMORY_ALLOCATED["xpu"] = torch.xpu.max_memory_allocated
+ BACKEND_MEMORY_ALLOCATED["xpu"] = torch.xpu.memory_allocated
+ BACKEND_SYNCHRONIZE["xpu"] = torch.xpu.synchronize
+ BACKEND_TORCH_ACCELERATOR_MODULE["xpu"] = torch.xpu
+
+
+if is_torch_xla_available():
+ BACKEND_EMPTY_CACHE["xla"] = torch.cuda.empty_cache
+ BACKEND_MANUAL_SEED["xla"] = torch.cuda.manual_seed
+ BACKEND_DEVICE_COUNT["xla"] = torch.cuda.device_count
+
+
+def backend_manual_seed(device: str, seed: int):
+ return _device_agnostic_dispatch(device, BACKEND_MANUAL_SEED, seed)
+
+
+def backend_empty_cache(device: str):
+ return _device_agnostic_dispatch(device, BACKEND_EMPTY_CACHE)
+
+
+def backend_device_count(device: str):
+ return _device_agnostic_dispatch(device, BACKEND_DEVICE_COUNT)
+
+
+def backend_reset_max_memory_allocated(device: str):
+ return _device_agnostic_dispatch(device, BACKEND_RESET_MAX_MEMORY_ALLOCATED)
+
+
+def backend_reset_peak_memory_stats(device: str):
+ return _device_agnostic_dispatch(device, BACKEND_RESET_PEAK_MEMORY_STATS)
+
+
+def backend_max_memory_allocated(device: str):
+ return _device_agnostic_dispatch(device, BACKEND_MAX_MEMORY_ALLOCATED)
+
+
+def backend_memory_allocated(device: str):
+ return _device_agnostic_dispatch(device, BACKEND_MEMORY_ALLOCATED)
+
+
+def backend_synchronize(device: str):
+ return _device_agnostic_dispatch(device, BACKEND_SYNCHRONIZE)
+
+
+def backend_torch_accelerator_module(device: str):
+ return _device_agnostic_dispatch(device, BACKEND_TORCH_ACCELERATOR_MODULE)
+
+
+if is_torch_available():
+ # If `TRANSFORMERS_TEST_DEVICE_SPEC` is enabled we need to import extra entries
+ # into device to function mappings.
+ if "TRANSFORMERS_TEST_DEVICE_SPEC" in os.environ:
+ device_spec_path = os.environ["TRANSFORMERS_TEST_DEVICE_SPEC"]
+ if not Path(device_spec_path).is_file():
+ raise ValueError(
+ f"Specified path to device spec file is not a file or not found. Received '{device_spec_path}"
+ )
+
+ # Try to strip extension for later import – also verifies we are importing a
+ # python file.
+ device_spec_dir, _ = os.path.split(os.path.realpath(device_spec_path))
+ sys.path.append(device_spec_dir)
+ try:
+ import_name = device_spec_path[: device_spec_path.index(".py")]
+ except ValueError as e:
+ raise ValueError(f"Provided device spec file was not a Python file! Received '{device_spec_path}") from e
+
+ device_spec_module = importlib.import_module(import_name)
+
+ # Imported file must contain `DEVICE_NAME`. If it doesn't, terminate early.
+ try:
+ device_name = device_spec_module.DEVICE_NAME
+ except AttributeError as e:
+ raise AttributeError("Device spec file did not contain `DEVICE_NAME`") from e
+
+ if "TRANSFORMERS_TEST_DEVICE" in os.environ and torch_device != device_name:
+ msg = f"Mismatch between environment variable `TRANSFORMERS_TEST_DEVICE` '{torch_device}' and device found in spec '{device_name}'\n"
+ msg += "Either unset `TRANSFORMERS_TEST_DEVICE` or ensure it matches device spec name."
+ raise ValueError(msg)
+
+ torch_device = device_name
+
+ def update_mapping_from_spec(device_fn_dict: dict[str, Callable], attribute_name: str):
+ try:
+ # Try to import the function directly
+ spec_fn = getattr(device_spec_module, attribute_name)
+ device_fn_dict[torch_device] = spec_fn
+ except AttributeError as e:
+ # If the function doesn't exist, and there is no default, throw an error
+ if "default" not in device_fn_dict:
+ raise AttributeError(
+ f"`{attribute_name}` not found in '{device_spec_path}' and no default fallback function found."
+ ) from e
+
+ # Add one entry here for each `BACKEND_*` dictionary.
+ update_mapping_from_spec(BACKEND_MANUAL_SEED, "MANUAL_SEED_FN")
+ update_mapping_from_spec(BACKEND_EMPTY_CACHE, "EMPTY_CACHE_FN")
+ update_mapping_from_spec(BACKEND_DEVICE_COUNT, "DEVICE_COUNT_FN")
+
+
+def compare_pipeline_output_to_hub_spec(output, hub_spec):
+ missing_keys = []
+ unexpected_keys = []
+ all_field_names = {field.name for field in fields(hub_spec)}
+ matching_keys = sorted([key for key in output if key in all_field_names])
+
+ # Fields with a MISSING default are required and must be in the output
+ for field in fields(hub_spec):
+ if field.default is MISSING and field.name not in output:
+ missing_keys.append(field.name)
+
+ # All output keys must match either a required or optional field in the Hub spec
+ for output_key in output:
+ if output_key not in all_field_names:
+ unexpected_keys.append(output_key)
+
+ if missing_keys or unexpected_keys:
+ error = ["Pipeline output does not match Hub spec!"]
+ if matching_keys:
+ error.append(f"Matching keys: {matching_keys}")
+ if missing_keys:
+ error.append(f"Missing required keys in pipeline output: {missing_keys}")
+ if unexpected_keys:
+ error.append(f"Keys in pipeline output that are not in Hub spec: {unexpected_keys}")
+ raise KeyError("\n".join(error))
+
+
+@require_torch
+def cleanup(device: str, gc_collect=False):
+ if gc_collect:
+ gc.collect()
+ backend_empty_cache(device)
+ torch._dynamo.reset()
+
+
+# Type definition of key used in `Expectations` class.
+DeviceProperties = tuple[Optional[str], Optional[int], Optional[int]]
+# Helper type. Makes creating instances of `Expectations` smoother.
+PackedDeviceProperties = tuple[Optional[str], Union[None, int, tuple[int, int]]]
+
+
+@cache
+def get_device_properties() -> DeviceProperties:
+ """
+ Get environment device properties.
+ """
+ if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
+ import torch
+
+ major, minor = torch.cuda.get_device_capability()
+ if IS_ROCM_SYSTEM:
+ return ("rocm", major, minor)
+ else:
+ return ("cuda", major, minor)
+ elif IS_XPU_SYSTEM:
+ import torch
+
+ # To get more info of the architecture meaning and bit allocation, refer to https://github.com/intel/llvm/blob/sycl/sycl/include/sycl/ext/oneapi/experimental/device_architecture.def
+ arch = torch.xpu.get_device_capability()["architecture"]
+ gen_mask = 0x000000FF00000000
+ gen = (arch & gen_mask) >> 32
+ return ("xpu", gen, None)
+ else:
+ return (torch_device, None, None)
+
+
+def unpack_device_properties(
+ properties: Optional[PackedDeviceProperties] = None,
+) -> DeviceProperties:
+ """
+ Unpack a `PackedDeviceProperties` tuple into consistently formatted `DeviceProperties` tuple. If properties is None, it is fetched.
+ """
+ if properties is None:
+ return get_device_properties()
+ device_type, major_minor = properties
+ if major_minor is None:
+ major, minor = None, None
+ elif isinstance(major_minor, int):
+ major, minor = major_minor, None
+ else:
+ major, minor = major_minor
+ return device_type, major, minor
+
+
+class Expectations(UserDict[PackedDeviceProperties, Any]):
+ def get_expectation(self) -> Any:
+ """
+ Find best matching expectation based on environment device properties. We look at device_type, major and minor
+ versions of the drivers. Expectations are stored as a dictionary with keys of the form
+ (device_type, (major, minor)). If the major and minor versions are not provided, we use None.
+ """
+ return self.find_expectation(get_device_properties())
+
+ def unpacked(self) -> list[tuple[DeviceProperties, Any]]:
+ return [(unpack_device_properties(k), v) for k, v in self.data.items()]
+
+ @staticmethod
+ def is_default(expectation_key: PackedDeviceProperties) -> bool:
+ """
+ This function returns True if the expectation_key is the Default expectation (None, None).
+ When an Expectation dict contains a Default value, it is generally because the test existed before Expectations.
+ When we modify a test to use Expectations for a specific hardware, we don't want to affect the tests on other
+ hardwares. Thus we set the previous value as the Default expectation with key (None, None) and add a value for
+ the specific hardware with key (hardware_type, (major, minor)).
+ """
+ return all(p is None for p in expectation_key)
+
+ @staticmethod
+ def score(properties: DeviceProperties, other: DeviceProperties) -> float:
+ """
+ Returns score indicating how similar two instances of the `Properties` tuple are.
+ Rules are as follows:
+ * Matching `type` adds one point, semi-matching `type` adds 0.1 point (e.g. cuda and rocm).
+ * If types match, matching `major` adds another point, and then matching `minor` adds another.
+ * The Default expectation (None, None) is worth 0.5 point, which is better than semi-matching. More on this
+ in the `is_default` function.
+ """
+ device_type, major, minor = properties
+ other_device_type, other_major, other_minor = other
+
+ score = 0
+ # Matching device type, maybe major and minor
+ if device_type is not None and device_type == other_device_type:
+ score += 1
+ if major is not None and major == other_major:
+ score += 1
+ if minor is not None and minor == other_minor:
+ score += 1
+ # Semi-matching device type, which carries less importance than the default expectation
+ elif device_type in ["cuda", "rocm"] and other_device_type in ["cuda", "rocm"]:
+ score = 0.1
+
+ # Default expectation
+ if Expectations.is_default(other):
+ score = 0.5
+
+ return score
+
+ def find_expectation(self, properties: DeviceProperties = (None, None, None)) -> Any:
+ """
+ Find best matching expectation based on provided device properties. We score each expectation, and to
+ distinguish between expectations with the same score, we use the major and minor version numbers, prioritizing
+ most recent versions.
+ """
+ (result_key, result) = max(
+ self.unpacked(),
+ key=lambda x: (
+ Expectations.score(properties, x[0]), # x[0] is a device properties tuple (device_type, major, minor)
+ x[0][1] if x[0][1] is not None else -1, # This key is the major version, -1 if major is None
+ x[0][2] if x[0][2] is not None else -1, # This key is the minor version, -1 if minor is None
+ ),
+ )
+
+ if Expectations.score(properties, result_key) == 0:
+ raise ValueError(f"No matching expectation found for {properties}")
+
+ return result
+
+ def __repr__(self):
+ return f"{self.data}"
+
+
+def patch_torch_compile_force_graph():
+ """
+ Patch `torch.compile` to always use `fullgraph=True`.
+
+ This is useful when some `torch.compile` tests are running with `fullgraph=False` and we want to be able to run
+ them with `fullgraph=True` in some occasion (without introducing new tests) to make sure there is no graph break.
+
+ After PR #40137, `CompileConfig.fullgraph` is `False` by default, this patch is necessary.
+ """
+
+ force_fullgraph = os.environ.get("TORCH_COMPILE_FORCE_FULLGRAPH", "")
+ force_fullgraph = force_fullgraph.lower() in ("yes", "true", "on", "t", "y", "1")
+
+ if force_fullgraph:
+ import torch
+
+ orig_method = torch.compile
+
+ def patched(*args, **kwargs):
+ # In `torch_compile`, all arguments except `model` is keyword only argument.
+ kwargs["fullgraph"] = True
+ return orig_method(*args, **kwargs)
+
+ torch.compile = patched
+
+
+def _get_test_info():
+ """
+ Collect some information about the current test.
+
+ For example, test full name, line number, stack, traceback, etc.
+ """
+
+ full_test_name = os.environ.get("PYTEST_CURRENT_TEST", "").split(" ")[0]
+ test_file, test_class, test_name = full_test_name.split("::")
+
+ # from the most recent frame to the top frame
+ stack_from_inspect = inspect.stack()
+ # but visit from the top frame to the most recent frame
+
+ actual_test_file, _actual_test_class = test_file, test_class
+ test_frame, test_obj, test_method = None, None, None
+ for frame in reversed(stack_from_inspect):
+ # if test_file in str(frame).replace(r"\\", "/"):
+ # check frame's function + if it has `self` as locals; double check if self has the (function) name
+ # TODO: Question: How about expanded?
+ if (
+ frame.function == test_name
+ and "self" in frame.frame.f_locals
+ and hasattr(frame.frame.f_locals["self"], test_name)
+ ):
+ # if test_name == frame.frame.f_locals["self"]._testMethodName:
+ test_frame = frame
+ # The test instance
+ test_obj = frame.frame.f_locals["self"]
+ # TODO: Do we get the (relative?) path or it's just a file name?
+ # TODO: Does `test_obj` always have `tearDown` object?
+ actual_test_file = frame.filename
+ # TODO: check `test_method` will work used at the several places!
+ test_method = getattr(test_obj, test_name)
+ break
+
+ if test_frame is not None:
+ line_number = test_frame.lineno
+
+ # The frame of `patched` being called (the one and the only one calling `_get_test_info`)
+ # This is used to get the original method being patched in order to get the context.
+ frame_of_patched_obj = None
+
+ captured_frames = []
+ to_capture = False
+ # From the most outer (i.e. python's `runpy.py`) frame to most inner frame (i.e. the frame of this method)
+ # Between `the test method being called` and `before entering `patched``.
+ for frame in reversed(stack_from_inspect):
+ if (
+ frame.function == test_name
+ and "self" in frame.frame.f_locals
+ and hasattr(frame.frame.f_locals["self"], test_name)
+ ):
+ to_capture = True
+ # TODO: check simply with the name is not robust.
+ elif "patched" == frame.frame.f_code.co_name:
+ frame_of_patched_obj = frame
+ to_capture = False
+ break
+ if to_capture:
+ captured_frames.append(frame)
+
+ tb_next = None
+ for frame_info in reversed(captured_frames):
+ tb = types.TracebackType(tb_next, frame_info.frame, frame_info.frame.f_lasti, frame_info.frame.f_lineno)
+ tb_next = tb
+ test_traceback = tb
+
+ origin_method_being_patched = frame_of_patched_obj.frame.f_locals["orig_method"]
+
+ # An iterable of type `traceback.StackSummary` with each element of type `FrameSummary`
+ stack = traceback.extract_stack()
+ # The frame which calls `the original method being patched`
+ caller_frame = None
+ # From the most inner (i.e. recent) frame to the most outer frame
+ for frame in reversed(stack):
+ if origin_method_being_patched.__name__ in frame.line:
+ caller_frame = frame
+
+ caller_path = os.path.relpath(caller_frame.filename)
+ caller_lineno = caller_frame.lineno
+
+ test_lineno = line_number
+
+ # Get the code context in the test function/method.
+ from _pytest._code.source import Source
+
+ with open(actual_test_file) as fp:
+ s = fp.read()
+ source = Source(s)
+ test_code_context = "\n".join(source.getstatement(test_lineno - 1).lines)
+
+ # Get the code context in the caller (to the patched function/method).
+ with open(caller_path) as fp:
+ s = fp.read()
+ source = Source(s)
+ caller_code_context = "\n".join(source.getstatement(caller_lineno - 1).lines)
+
+ test_info = f"test:\n\n{full_test_name}\n\n{'-' * 80}\n\ntest context: {actual_test_file}:{test_lineno}\n\n{test_code_context}"
+ test_info = f"{test_info}\n\n{'-' * 80}\n\ncaller context: {caller_path}:{caller_lineno}\n\n{caller_code_context}"
+
+ return (
+ full_test_name,
+ test_file,
+ test_lineno,
+ test_obj,
+ test_method,
+ test_frame,
+ test_traceback,
+ test_code_context,
+ caller_path,
+ caller_lineno,
+ caller_code_context,
+ test_info,
+ )
+
+
+def _get_call_arguments(code_context):
+ """
+ Analyze the positional and keyword arguments in a call expression.
+
+ This will extract the expressions of the positional and kwyword arguments, and associate them to the positions and
+ the keyword arugment names.
+ """
+
+ def get_argument_name(node):
+ """Extract the name/expression from an AST node"""
+ if isinstance(node, ast.Name):
+ return node.id
+ elif isinstance(node, ast.Attribute):
+ return ast.unparse(node)
+ elif isinstance(node, ast.Constant):
+ return repr(node.value)
+ else:
+ return ast.unparse(node)
+
+ indent = len(code_context) - len(code_context.lstrip())
+ code_context = code_context.replace(" " * indent, "")
+
+ try:
+ # Parse the line
+ tree = ast.parse(code_context, mode="eval")
+
+ assert isinstance(tree.body, ast.Call)
+ call_node = tree.body
+
+ if call_node:
+ result = {
+ "positional_args": [],
+ "keyword_args": {},
+ "starargs": None, # *args
+ "kwargs": None, # **kwargs
+ }
+
+ # Extract positional arguments
+ for arg in call_node.args:
+ arg_name = get_argument_name(arg)
+ result["positional_args"].append(arg_name)
+
+ # Extract keyword arguments
+ for keyword in call_node.keywords:
+ if keyword.arg is None:
+ # This is **kwargs
+ result["kwargs"] = get_argument_name(keyword.value)
+ else:
+ # Regular keyword argument
+ arg_name = get_argument_name(keyword.value)
+ result["keyword_args"][keyword.arg] = arg_name
+
+ return result
+
+ except (SyntaxError, AttributeError) as e:
+ print(f"Error parsing: {e}")
+
+ return None
+
+
+def _prepare_debugging_info(test_info, info):
+ """Combine the information about the test and the call information to a patched function/method within it."""
+
+ info = f"{test_info}\n\n{info}"
+ p = os.path.join(os.environ.get("_PATCHED_TESTING_METHODS_OUTPUT_DIR", ""), "captured_info.txt")
+ # TODO (ydshieh): This is not safe when we use pytest-xdist with more than 1 worker.
+ with open(p, "a") as fp:
+ fp.write(f"{info}\n\n{'=' * 120}\n\n")
+
+ return info
+
+
+def _patched_tearDown(self, *args, **kwargs):
+ """Used to report a test that has failures captured and handled by patched functions/methods (without re-raise).
+
+ The patched functions/methods refer to the `patched` defined in `_patch_with_call_info`, which is applied to
+ `torch.testing.assert_close` and `unittest.case.TestCase.assertEqual`.
+
+ The objective is to avoid a failure being silence after being processed.
+
+ If there is any failure that is not handled by the patched functions/methods, we add custom error message for them
+ along with the usual pytest failure report.
+ """
+
+ # Check for regular failures before clearing:
+ # when `_patched_tearDown` is called, the current test fails due to an assertion error given by a method being
+ # patched by `_patch_with_call_info`. The patched method catches such an error and continue running the remaining
+ # statements within the test. If the test fails with another error not handled by the patched methods, we don't let
+ # pytest to fail and report it but the original failure (the first one that was processed) instead.
+ # We still record those failures not handled by the patched methods, and add custom messages along with the usual
+ # pytest failure report.
+ regular_failures_info = []
+ if hasattr(self, "_outcome") and self._outcome.errors:
+ for error_entry in self._outcome.errors:
+ test_instance, (exc_type, exc_obj, exc_tb) = error_entry
+ # breakpoint()
+ regular_failures_info.append(
+ {
+ "message": f"{str(exc_obj)}\n\n",
+ "type": exc_type.__name__,
+ "file": "test_modeling_vit.py",
+ "line": 237, # get_deepest_frame_line(exc_tb) # Your helper function
+ }
+ )
+
+ # Clear the regular failure (i.e. that is not from any of our patched assertion methods) from pytest's records.
+ self._outcome.errors.clear()
+
+ # reset back to the original tearDown method, so `_patched_tearDown` won't be run by the subsequent tests if they
+ # have only test failures that are not handle by the patched methods (or no test failure at all).
+ orig_tearDown = _patched_tearDown.orig_tearDown
+ type(self).tearDown = orig_tearDown
+
+ # Call the original tearDown
+ orig_tearDown(self, *args, **kwargs)
+
+ # Get the failure
+ test_method = getattr(self, self._testMethodName)
+ captured_failures = test_method.__func__.captured_failures[id(test_method)]
+
+ # TODO: How could we show several exceptions in a sinigle test on the terminal? (Maybe not a good idea)
+ captured_exceptions = captured_failures[0]["exception"]
+ captured_traceback = captured_failures[0]["traceback"]
+ # Show the cpatured information on the terminal.
+ capturued_info = [x["info"] for x in captured_failures]
+ capturued_info_str = f"\n\n{'=' * 80}\n\n".join(capturued_info)
+
+ # Enhance the exception message if there were suppressed failures
+ if regular_failures_info:
+ enhanced_message = f"""{str(captured_exceptions)}
+
+{"=" * 80}
+Handled Failures: ({len(capturued_info)} handled):
+{"-" * 80}\n
+{capturued_info_str}
+
+{"=" * 80}
+Unhandled Failures: ({len(regular_failures_info)} unhandled):
+{"-" * 80}\n
+{", ".join(f"{info['type']}: {info['message']}{info['file']}:{info['line']}" for info in regular_failures_info)}
+
+{"-" * 80}
+Note: This failure occurred after other failures analyzed by the patched assertion methods.
+To see the full details, temporarily disable assertion patching.
+{"=" * 80}"""
+
+ # Create new exception with enhanced message
+ enhanced_exception = type(captured_exceptions)(enhanced_message)
+ enhanced_exception.__cause__ = captured_exceptions.__cause__
+ enhanced_exception.__context__ = captured_exceptions.__context__
+
+ # Raise with your existing traceback reconstruction
+ captured_exceptions = enhanced_exception
+
+ # clean up the recorded status
+ del test_method.__func__.captured_failures
+
+ raise captured_exceptions.with_traceback(captured_traceback)
+
+
+def _patch_with_call_info(module_or_class, attr_name, _parse_call_info_func, target_args):
+ """
+ Patch a callerable `attr_name` of a module or class `module_or_class`.
+
+ This will allow us to collect the call information, e.g. the argument names and values, also the literal expressions
+ passed as the arguments.
+ """
+ orig_method = getattr(module_or_class, attr_name)
+ if not callable(orig_method):
+ return
+
+ def patched(*args, **kwargs):
+ # If the target callable is not called within a test, simply call it without modification.
+ if not os.environ.get("PYTEST_CURRENT_TEST", ""):
+ return orig_method(*args, **kwargs)
+
+ try:
+ orig_method(*args, **kwargs)
+ except AssertionError as e:
+ captured_exception = e
+ # captured_traceback = e.__traceback__
+ (
+ full_test_name,
+ test_file,
+ test_lineno,
+ test_obj,
+ test_method,
+ test_frame,
+ test_traceback,
+ test_code_context,
+ caller_path,
+ caller_lineno,
+ caller_code_context,
+ test_info,
+ ) = _get_test_info()
+ test_info = f"{test_info}\n\n{'-' * 80}\n\npatched method: {orig_method.__module__}.{orig_method.__name__}"
+ call_argument_expressions = _get_call_arguments(caller_code_context)
+
+ # This is specific
+ info = _parse_call_info_func(orig_method, args, kwargs, call_argument_expressions, target_args)
+ info = _prepare_debugging_info(test_info, info)
+
+ # If the test is running in a CI environment (e.g. not a manual run), let's raise and fail the test, so it
+ # behaves as usual.
+ # On Github Actions or CircleCI, this is set automatically.
+ # When running manually, it's the user to determine if to set it.
+ # This is to avoid the patched function being called `with self.assertRaises(AssertionError):` and fails
+ # because of the missing expected `AssertionError`.
+ # TODO (ydshieh): If there is way to raise only when we are inside such context managers?
+ # TODO (ydshieh): How not to record the failure if it happens inside `self.assertRaises(AssertionError)`?
+ if os.getenv("CI") == "true":
+ raise captured_exception.with_traceback(test_traceback)
+
+ # Save this, so we can raise at the end of the current test
+ captured_failure = {
+ "result": "failed",
+ "exception": captured_exception,
+ "traceback": test_traceback,
+ "info": info,
+ }
+
+ # Record the failure status and its information, so we can raise it later.
+ # We are modifying the (unbound) function at class level: not its logic but only adding a new extra
+ # attribute.
+ if getattr(test_method.__func__, "captured_failures", None) is None:
+ test_method.__func__.captured_failures = {}
+ if id(test_method) not in test_method.__func__.captured_failures:
+ test_method.__func__.captured_failures[id(test_method)] = []
+ test_method.__func__.captured_failures[id(test_method)].append(captured_failure)
+
+ # This modifies the `tearDown` which will be called after every tests, but we reset it back inside
+ # `_patched_tearDown`.
+ if not hasattr(type(test_obj).tearDown, "orig_tearDown"):
+ orig_tearDown = type(test_obj).tearDown
+ _patched_tearDown.orig_tearDown = orig_tearDown
+ type(test_obj).tearDown = _patched_tearDown
+
+ setattr(module_or_class, attr_name, patched)
+
+
+def _parse_call_info(func, args, kwargs, call_argument_expressions, target_args):
+ """
+ Prepare a string containing the call info to `func`, e.g. argument names/values/expressions.
+ """
+ signature = inspect.signature(func)
+ signature_names = [param.name for param_name, param in signature.parameters.items()]
+
+ # called as `self.method_name()` or `xxx.method_name()`.
+ if len(args) == len(call_argument_expressions["positional_args"]) + 1:
+ # We simply add "self" as the expression despite it might not be the actual argument name.
+ # (This part is very unlikely what a user would be interest to know)
+ call_argument_expressions["positional_args"] = ["self"] + call_argument_expressions["positional_args"]
+
+ param_position_mapping = {param_name: idx for idx, param_name in enumerate(signature_names)}
+
+ arg_info = {}
+ for arg_name in target_args:
+ if arg_name in kwargs:
+ arg_value = kwargs[arg_name]
+ arg_expr = call_argument_expressions["keyword_args"][arg_name]
+ else:
+ arg_pos = param_position_mapping[arg_name]
+ arg_value = args[arg_pos]
+ arg_expr = call_argument_expressions["positional_args"][arg_pos]
+
+ arg_value_str = _format_py_obj(arg_value)
+ arg_info[arg_name] = {"arg_expr": arg_expr, "arg_value_str": arg_value_str}
+
+ info = ""
+ for arg_name in arg_info:
+ arg_expr, arg_value_str = arg_info[arg_name]["arg_expr"], arg_info[arg_name]["arg_value_str"]
+ info += f"{'-' * 80}\n\nargument name: `{arg_name}`\nargument expression: `{arg_expr}`\n\nargument value:\n\n{arg_value_str}\n\n"
+
+ # remove the trailing \n\n
+ info = info[:-2]
+
+ return info
+
+
+def patch_testing_methods_to_collect_info():
+ """
+ Patch some methods (`torch.testing.assert_close`, `unittest.case.TestCase.assertEqual`, etc).
+
+ This will allow us to collect the call information, e.g. the argument names and values, also the literal expressions
+ passed as the arguments.
+ """
+ p = os.path.join(os.environ.get("_PATCHED_TESTING_METHODS_OUTPUT_DIR", ""), "captured_info.txt")
+ Path(p).unlink(missing_ok=True)
+
+ if is_torch_available():
+ import torch
+
+ _patch_with_call_info(torch.testing, "assert_close", _parse_call_info, target_args=("actual", "expected"))
+
+ _patch_with_call_info(unittest.case.TestCase, "assertEqual", _parse_call_info, target_args=("first", "second"))
+ _patch_with_call_info(unittest.case.TestCase, "assertListEqual", _parse_call_info, target_args=("list1", "list2"))
+ _patch_with_call_info(
+ unittest.case.TestCase, "assertTupleEqual", _parse_call_info, target_args=("tuple1", "tuple2")
+ )
+ _patch_with_call_info(unittest.case.TestCase, "assertSetEqual", _parse_call_info, target_args=("set1", "set1"))
+ _patch_with_call_info(unittest.case.TestCase, "assertDictEqual", _parse_call_info, target_args=("d1", "d2"))
+ _patch_with_call_info(unittest.case.TestCase, "assertIn", _parse_call_info, target_args=("member", "container"))
+ _patch_with_call_info(unittest.case.TestCase, "assertNotIn", _parse_call_info, target_args=("member", "container"))
+ _patch_with_call_info(unittest.case.TestCase, "assertLess", _parse_call_info, target_args=("a", "b"))
+ _patch_with_call_info(unittest.case.TestCase, "assertLessEqual", _parse_call_info, target_args=("a", "b"))
+ _patch_with_call_info(unittest.case.TestCase, "assertGreater", _parse_call_info, target_args=("a", "b"))
+ _patch_with_call_info(unittest.case.TestCase, "assertGreaterEqual", _parse_call_info, target_args=("a", "b"))
+
+
+def torchrun(script: str, nproc_per_node: int, is_torchrun: bool = True, env: Optional[dict] = None):
+ """Run the `script` using `torchrun` command for multi-processing in a subprocess. Captures errors as necessary."""
+ with tempfile.NamedTemporaryFile(mode="w+", suffix=".py") as tmp:
+ tmp.write(script)
+ tmp.flush()
+ tmp.seek(0)
+ if is_torchrun:
+ cmd = (
+ f"torchrun --nproc_per_node {nproc_per_node} --master_port {get_torch_dist_unique_port()} {tmp.name}"
+ ).split()
+ else:
+ cmd = ["python3", tmp.name]
+
+ # Note that the subprocess will be waited for here, and raise an error if not successful
+ try:
+ _ = subprocess.run(cmd, capture_output=True, env=env, text=True, check=True)
+ except subprocess.CalledProcessError as e:
+ raise Exception(f"The following error was captured: {e.stderr}")
+
+
+def _format_tensor(t, indent_level=0, sci_mode=None):
+ """Format torch's tensor in a pretty way to be shown 👀 in the test report."""
+
+ # `torch.testing.assert_close` could accept python int/float numbers.
+ if not isinstance(t, torch.Tensor):
+ t = torch.tensor(t)
+
+ # Simply make the processing below simpler (not to hande both case)
+ is_scalar = False
+ if t.ndim == 0:
+ t = torch.tensor([t])
+ is_scalar = True
+
+ # For scalar or one-dimensional tensor, keep it as one-line. If there is only one element along any dimension except
+ # the last one, we also keep it as one-line.
+ if t.ndim <= 1 or set(t.shape[0:-1]) == {1}:
+ # Use `detach` to remove `grad_fn=<...>`, and use `to("cpu")` to remove `device='...'`
+ t = t.detach().to("cpu")
+
+ # We work directly with the string representation instead the tensor itself
+ t_str = str(t)
+
+ # remove `tensor( ... )` so keep only the content
+ t_str = t_str.replace("tensor(", "").replace(")", "")
+
+ # Sometimes there are extra spaces between `[` and the first digit of the first value (for alignment).
+ # For example `[[ 0.06, -0.51], [-0.76, -0.49]]`. It may have multiple consecutive spaces.
+ # Let's remove such extra spaces.
+ while "[ " in t_str:
+ t_str = t_str.replace("[ ", "[")
+
+ # Put everything in a single line. We replace `\n` by a space ` ` so we still keep `,\n` as `, `.
+ t_str = t_str.replace("\n", " ")
+
+ # Remove repeated spaces (introduced by the previous step)
+ while " " in t_str:
+ t_str = t_str.replace(" ", " ")
+
+ # remove leading `[` and `]` for scalar tensor
+ if is_scalar:
+ t_str = t_str[1:-1]
+
+ t_str = " " * 4 * indent_level + t_str
+
+ return t_str
+
+ # Otherwise, we separte the representations of every elements along an outer dimension by new lines (after a `,`).
+ # The representatioin each element is obtained by calling this function recursively with corrent `indent_level`.
+ else:
+ t_str = str(t)
+
+ # (For the recursive calls should receive this value)
+ if sci_mode is None:
+ sci_mode = "e+" in t_str or "e-" in t_str
+
+ # Use the original content to determine the scientific mode to use. This is required as the representation of
+ # t[index] (computed below) maybe have different format regarding scientific notation.
+ torch.set_printoptions(sci_mode=sci_mode)
+
+ t_str = " " * 4 * indent_level + "[\n"
+ # Keep the ending `,` for all outer dimensions whose representations are not put in one-line, even if there is
+ # only one element along that dimension.
+ t_str += ",\n".join(_format_tensor(x, indent_level=indent_level + 1, sci_mode=sci_mode) for x in t)
+ t_str += ",\n" + " " * 4 * indent_level + "]"
+
+ torch.set_printoptions(sci_mode=None)
+
+ return t_str
+
+
+def _quote_string(s):
+ """Given a string `s`, return a python literal expression that give `s` when it is used in a python source code.
+
+ For example, if `s` is the string `abc`, the return value is `"abc"`.
+
+ We choice double quotes over single quote despite `str(s)` would give `'abc'` instead of `"abc"`.
+ """
+ has_single_quote = "'" in s
+ has_double_quote = '"' in s
+
+ if has_single_quote and has_double_quote:
+ # replace any double quote by the raw string r'\"'.
+ s = s.replace('"', r"\"")
+ return f'"{s}"'
+ elif has_single_quote:
+ return f'"{s}"'
+ elif has_double_quote:
+ return f"'{s}'"
+ else:
+ return f'"{s}"'
+
+
+def _format_py_obj(obj, indent=0, mode="", cache=None, prefix=""):
+ """Format python objects of basic built-in type in a pretty way so we could copy-past them to code editor easily.
+
+ Currently, this support int, float, str, list, tuple, and dict.
+
+ It also works with `torch.Tensor` via calling `format_tesnor`.
+ """
+
+ if cache is None:
+ cache = {}
+ else:
+ if (id(obj), indent, mode, prefix) in cache:
+ return cache[(id(obj), indent, mode, prefix)]
+
+ # special format method for `torch.Tensor`
+ if str(obj.__class__) == "":
+ return _format_tensor(obj)
+
+ elif obj.__class__.__name__ == "str":
+ quoted_string = _quote_string(obj)
+ # we don't want the newline being interpreted
+ quoted_string = quoted_string.replace("\n", r"\n")
+ output = quoted_string
+
+ elif obj.__class__.__name__ in ["int", "float"]:
+ # for float like `1/3`, we will get `0.3333333333333333`
+ output = str(obj)
+
+ elif obj.__class__.__name__ in ["list", "tuple", "dict"]:
+ parenthesis = {
+ "list": "[]",
+ "tuple": "()",
+ "dict": "{}",
+ }
+ p1, p2 = parenthesis[obj.__class__.__name__]
+
+ elements_without_indent = []
+ if isinstance(obj, dict):
+ for idx, (k, v) in enumerate(obj.items()):
+ last_element = idx == len(obj) - 1
+ ok = _format_py_obj(k, indent=indent + 1, mode="one-line", cache=cache)
+ ov = _format_py_obj(
+ v,
+ indent=indent + 1,
+ mode=mode,
+ cache=cache,
+ prefix=ok.lstrip() + ": " + "," if not last_element else "",
+ )
+ # Each element could be multiple-line, but the indent of its first line is removed
+ elements_without_indent.append(f"{ok.lstrip()}: {ov.lstrip()}")
+
+ else:
+ for idx, x in enumerate(obj):
+ last_element = idx == len(obj) - 1
+ o = _format_py_obj(
+ x, indent=indent + 1, mode=mode, cache=cache, prefix="," if not last_element else ""
+ )
+ # Each element could be multiple-line, but the indent of its first line is removed
+ elements_without_indent.append(o.lstrip())
+
+ groups = []
+ buf = []
+ for idx, x in enumerate(elements_without_indent):
+ buf.append(x)
+
+ x_expanded = "\n" in buf[-1]
+ not_last_element = idx != len(elements_without_indent) - 1
+ # if `x` should be separated from subsequent elements
+ should_finalize_x = x_expanded or len(f"{' ' * (4 * (indent + 1))}") + len(
+ ", ".join(buf[-1:])
+ ) > 120 - int(not_last_element)
+
+ # if `buf[:-1]` (i.e. without `x`) should be combined together (into one line)
+ should_finalize_buf = x_expanded
+
+ # the recursive call returns single line, so we can use it to determine if we can fit the width limit
+ if not should_finalize_buf:
+ buf_not_fit_into_one_line = len(f"{' ' * (4 * (indent + 1))}") + len(", ".join(buf)) > 120 - int(
+ not_last_element
+ )
+ should_finalize_buf = buf_not_fit_into_one_line
+
+ # any element of iterable type need to be on its own line
+ if (type(obj[idx]) if type(obj) is not dict else type(list(obj.values())[idx])) in [list, tuple, dict]:
+ should_finalize_x = True
+ should_finalize_buf = True
+
+ # any type change --> need to be added after a new line
+ prev_type = None
+ current_type = type(obj[idx]) if type(obj) is not dict else type(list(obj.values())[idx])
+ if len(buf) > 1:
+ prev_type = type(obj[idx - 1]) if type(obj) is not dict else type(list(obj.values())[idx - 1])
+ type_changed = current_type != prev_type
+ if type_changed:
+ should_finalize_buf = True
+
+ # all elements in the buf are string --> don't finalize the buf by width limit
+ if prev_type is None or (prev_type is str and current_type is str):
+ should_finalize_buf = False
+
+ # collect as many elements of string type as possible (without width limit).
+ # These will be examined as a whole (if not fit into the width, each element would be in its own line)
+ if current_type is str:
+ should_finalize_x = False
+ # `len(buf) == 1` or `obj[idx-1]` is a string
+ if prev_type in [None, str]:
+ should_finalize_buf = False
+
+ if should_finalize_buf:
+ orig_buf_len = len(buf)
+
+ if orig_buf_len > 1:
+ not_fit_into_one_line = None
+
+ # all elements in `obj` that give `buf[:-1]` are string.
+ if prev_type is str:
+ # `-1` at the end: because buf[-2] is not the last element
+ not_fit_into_one_line = len(f"{' ' * (4 * (indent + 1))}") + len(", ".join(buf[:-1])) > 120 - 1
+
+ if not_fit_into_one_line:
+ for x in buf[:-1]:
+ groups.append([x])
+ else:
+ groups.append(buf[:-1])
+
+ buf = buf[-1:]
+
+ if should_finalize_x:
+ groups.append(buf)
+ buf = []
+
+ # The last buf
+ if len(buf) > 0:
+ not_fit_into_one_line = None
+ if current_type is str:
+ # no `-1` at the end: because buf[-1] is the last element
+ not_fit_into_one_line = len(f"{' ' * (4 * (indent + 1))}") + len(", ".join(buf)) > 120
+
+ if not_fit_into_one_line:
+ for x in buf:
+ groups.append([x])
+ else:
+ groups.append(buf)
+
+ output = f"{' ' * 4 * indent}{p1}\n"
+ element_strings = [f"{' ' * (4 * (indent + 1))}" + ", ".join(buf) for buf in groups]
+ output += ",\n".join(element_strings)
+ output += f"\n{' ' * 4 * indent}{p2}"
+
+ # if all elements are in one-line
+ no_new_line_in_elements = all("\n" not in x for x in element_strings)
+ # if yes, we can form a one-line representation of `obj`
+ could_use_one_line = no_new_line_in_elements
+
+ # if mode == "one-line", this function always returns one-line representation, so `no_new_line_in_elements`
+ # will be `True`.
+ if could_use_one_line:
+ one_line_form = ", ".join([x.lstrip() for x in element_strings])
+ one_line_form = f"{p1}{one_line_form}{p2}"
+
+ if mode == "one-line":
+ return output
+
+ # check with the width limit
+ could_use_one_line = len(f"{' ' * 4 * indent}") + len(prefix) + len(one_line_form) <= 120
+
+ # extra conditions for returning one-line representation
+ def use_one_line_repr(obj):
+ # interable types
+ if type(obj) in (list, tuple, dict):
+ # get all types
+ element_types = []
+ if type(obj) is dict:
+ element_types.extend(type(x) for x in obj.values())
+ elif type(obj) in [list, tuple]:
+ element_types.extend(type(x) for x in obj)
+
+ # At least one element is of iterable type
+ if any(x in (list, tuple, dict) for x in element_types):
+ # If `obj` has more than one element and at least one of them is iterable --> no one line repr.
+ if len(obj) > 1:
+ return False
+
+ # only one element that is iterable, but not the same type as `obj` --> no one line repr.
+ if type(obj) is not type(obj[0]):
+ return False
+
+ # one-line repr. if possible, without width limit
+ return no_new_line_in_elements
+
+ # all elements are of simple types, but more than one type --> no one line repr.
+ if len(set(element_types)) > 1:
+ return False
+
+ # all elements are of the same simple type
+ if element_types[0] in [int, float]:
+ # one-line repr. without width limit
+ return no_new_line_in_elements
+ elif element_types[0] is str:
+ if len(obj) == 1:
+ # one single string element --> one-line repr. without width limit
+ return no_new_line_in_elements
+ else:
+ # multiple string elements --> one-line repr. if fit into width limit
+ return could_use_one_line
+
+ # simple types (int, flat, string)
+ return True
+
+ # width condition combined with specific mode conditions
+ if use_one_line_repr(obj):
+ output = f"{' ' * 4 * indent}{one_line_form}"
+
+ cache[(id(obj), indent, mode, prefix)] = output
+
+ return output
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/tf_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/tf_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..11d07f8d7edab7401bd4582b57a095db6552475a
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/tf_utils.py
@@ -0,0 +1,294 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Union
+
+import numpy as np
+import tensorflow as tf
+
+from .feature_extraction_utils import BatchFeature
+from .tokenization_utils_base import BatchEncoding
+from .utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+def shape_list(tensor: Union[tf.Tensor, np.ndarray]) -> list[int]:
+ """
+ Deal with dynamic shape in tensorflow cleanly.
+
+ Args:
+ tensor (`tf.Tensor` or `np.ndarray`): The tensor we want the shape of.
+
+ Returns:
+ `list[int]`: The shape of the tensor as a list.
+ """
+ if isinstance(tensor, np.ndarray):
+ return list(tensor.shape)
+
+ dynamic = tf.shape(tensor)
+
+ if tensor.shape == tf.TensorShape(None):
+ return dynamic
+
+ static = tensor.shape.as_list()
+
+ return [dynamic[i] if s is None else s for i, s in enumerate(static)]
+
+
+def stable_softmax(logits: tf.Tensor, axis: Optional[int] = None, name: Optional[str] = None) -> tf.Tensor:
+ """
+ Stable wrapper that returns the same output as `tf.nn.softmax`, but that works reliably with XLA on CPU. It is
+ meant as a workaround for the [following issue](https://github.com/tensorflow/tensorflow/issues/55682), and will be
+ removed after it gets fixed. The arguments and outputs are the same as `tf.nn.softmax`, and relies on the fact that
+ `softmax(x) = softmax(x + c)` (see https://ogunlao.github.io/2020/04/26/you_dont_really_know_softmax.html).
+
+ Args:
+ logits (`tf.Tensor`):
+ Must be one of the following types: half, float32, float64.
+ axis (`int`, *optional*):
+ The dimension softmax would be performed on. The default is -1 which indicates the last dimension.
+ name (`str`, *optional*):
+ A name for the operation.
+
+ Returns:
+ `tf.Tensor`:
+ A Tensor. Has the same type and shape as logits.
+ """
+ # TODO: When the issue linked above gets sorted, add a check on TF version here and use the original function if
+ # it has the fix. After we drop the support for unfixed versions, remove this function.
+ return tf.nn.softmax(logits=logits + 1e-9, axis=axis, name=name)
+
+
+def functional_layernorm(inputs, weight, bias, epsilon=1e-5, axis=-1):
+ # This is a very simplified functional layernorm, designed to duplicate
+ # the functionality of PyTorch nn.functional.layer_norm when this is needed to port
+ # models in Transformers.
+
+ if weight.shape.rank != 1 or bias.shape.rank != 1 or not isinstance(axis, int):
+ raise NotImplementedError("Only 1D weight and bias tensors are supported for now, with only a single axis.")
+
+ # Get mean and variance on the axis to be normalized
+ mean, variance = tf.nn.moments(inputs, axes=[axis], keepdims=True)
+
+ if axis != -1:
+ # Reshape scale and weight to have the same rank as inputs, but with 1 dimensions
+ # on every dimension except axis
+ shape = [1] * inputs.shape.rank
+ shape[axis] = shape_list(inputs)[axis]
+ weight = tf.reshape(weight, shape)
+ bias = tf.reshape(bias, shape)
+
+ # Compute layer normalization using the batch_normalization
+ # function.
+ outputs = tf.nn.batch_normalization(
+ inputs,
+ mean,
+ variance,
+ offset=bias,
+ scale=weight,
+ variance_epsilon=epsilon,
+ )
+ return outputs
+
+
+def scaled_dot_product_attention(
+ query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale: Optional[float] = None
+):
+ """TF equivalent for torch's nn.functional.scaled_dot_product_attention"""
+ if dropout_p != 0.0:
+ raise ValueError(
+ "Dropout is not supported in this implementation - file an issue "
+ "with Transformers and ping @Rocketknight1 if you need it for a port!"
+ )
+ if is_causal and attn_mask is not None:
+ raise ValueError("You cannot specify an attn_mask and is_causal at the same time!")
+ if is_causal:
+ attn_mask = tf.ones((tf.shape(query)[-2], tf.shape(key)[-2]), dtype=tf.int32)
+ attn_mask = tf.experimental.numpy.tril(attn_mask, k=0)
+ if attn_mask is not None and (attn_mask.dtype.is_integer or attn_mask.dtype.is_bool):
+ # Convert boolean mask to a negative logit bias
+ attn_mask = tf.where(attn_mask > 0, tf.cast(0.0, query.dtype), tf.cast(-1000.0, query.dtype))
+ logits = tf.einsum("...qd, ...kd -> ...qk", query, key)
+ if scale is None:
+ scale = tf.cast(tf.shape(key)[-1], logits.dtype) ** -0.5
+ logits *= scale # scale by 1/sqrt(key_dim)
+ if attn_mask is not None:
+ logits += attn_mask
+ probs = tf.nn.softmax(logits)
+ return probs @ value
+
+
+def flatten(input, start_dim=0, end_dim=-1):
+ # Replicates the behavior of torch.flatten in TF
+
+ # If end_dim or start_dim is negative, count them from the end
+ if end_dim < 0:
+ end_dim += input.shape.rank
+ if start_dim < 0:
+ start_dim += input.shape.rank
+
+ if start_dim == end_dim:
+ return input
+
+ in_shape = tf.shape(input)
+ flattened_dim = tf.math.reduce_prod(in_shape[start_dim : end_dim + 1])
+ out_shape = tf.concat([in_shape[:start_dim], [flattened_dim], in_shape[end_dim + 1 :]], axis=0)
+ return tf.reshape(input, out_shape)
+
+
+def invert_attention_mask(encoder_attention_mask: tf.Tensor) -> tf.Tensor:
+ """
+ Invert an attention mask (e.g., switches 0. and 1.).
+
+ Args:
+ encoder_attention_mask (`torch.Tensor`): An attention mask.
+
+ Returns:
+ `tf.Tensor`: The inverted attention mask.
+ """
+ if not isinstance(encoder_attention_mask, tf.Tensor):
+ encoder_attention_mask = tf.convert_to_tensor(encoder_attention_mask) # Catches stray NumPy inputs
+ if encoder_attention_mask.shape.rank == 3:
+ encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
+ if encoder_attention_mask.shape.rank == 2:
+ encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
+ # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
+ # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow
+ # /transformer/transformer_layers.py#L270
+ # encoder_extended_attention_mask = (encoder_extended_attention_mask ==
+ # encoder_extended_attention_mask.transpose(-1, -2))
+ encoder_extended_attention_mask = (
+ tf.cast(1, encoder_attention_mask.dtype) - encoder_extended_attention_mask
+ ) * encoder_extended_attention_mask.dtype.min
+
+ return encoder_extended_attention_mask
+
+
+def check_embeddings_within_bounds(tensor: tf.Tensor, embed_dim: int, tensor_name: str = "input_ids") -> None:
+ """
+ `tf.gather`, on which TF embedding layers are based, won't check positive out of bound indices on GPU, returning
+ zeros instead. This function adds a check against that dangerous silent behavior.
+
+ Args:
+ tensor (`tf.Tensor`): The tensor of indices to check.
+ embed_dim (`int`): The embedding dimension.
+ tensor_name (`str`, *optional*): The name of the tensor to use in the error message.
+ """
+ tf.debugging.assert_less(
+ tensor,
+ tf.cast(embed_dim, dtype=tensor.dtype),
+ message=(
+ f"The maximum value of {tensor_name} ({tf.math.reduce_max(tensor)}) must be smaller than the embedding "
+ f"layer's input dimension ({embed_dim}). The likely cause is some problem at tokenization time."
+ ),
+ )
+
+
+def save_attributes_to_hdf5_group(group, name, data):
+ """Saves attributes (data) of the specified name into the HDF5 group.
+
+ This method deals with an inherent problem of HDF5 file which is not able to store data larger than
+ HDF5_OBJECT_HEADER_LIMIT bytes.
+
+ Args:
+ group: A pointer to a HDF5 group.
+ name: A name of the attributes to save.
+ data: Attributes data to store.
+
+ Raises:
+ RuntimeError: If any single attribute is too large to be saved.
+
+ Copied from Keras to Transformers to avoid versioning issues.
+ """
+ HDF5_OBJECT_HEADER_LIMIT = 64512
+ # Check that no item in `data` is larger than `HDF5_OBJECT_HEADER_LIMIT`
+ # because in that case even chunking the array would not make the saving
+ # possible.
+ bad_attributes = [x for x in data if len(x) > HDF5_OBJECT_HEADER_LIMIT]
+
+ # Expecting this to never be true.
+ if bad_attributes:
+ raise RuntimeError(
+ "The following attributes cannot be saved to HDF5 file because "
+ f"they are larger than {HDF5_OBJECT_HEADER_LIMIT} "
+ f"bytes: {bad_attributes}"
+ )
+
+ data_npy = np.asarray(data)
+
+ num_chunks = 1
+ chunked_data = np.array_split(data_npy, num_chunks)
+
+ # This will never loop forever thanks to the test above.
+ while any(x.nbytes > HDF5_OBJECT_HEADER_LIMIT for x in chunked_data):
+ num_chunks += 1
+ chunked_data = np.array_split(data_npy, num_chunks)
+
+ if num_chunks > 1:
+ for chunk_id, chunk_data in enumerate(chunked_data):
+ group.attrs["%s%d" % (name, chunk_id)] = chunk_data
+ else:
+ group.attrs[name] = data
+
+
+def load_attributes_from_hdf5_group(group, name):
+ """Loads attributes of the specified name from the HDF5 group.
+
+ This method deals with an inherent problem of HDF5 file which is not able to store data larger than
+ HDF5_OBJECT_HEADER_LIMIT bytes.
+
+ Args:
+ group: A pointer to a HDF5 group.
+ name: A name of the attributes to load.
+
+ Returns:
+ data: Attributes data.
+
+ Copied from Keras to Transformers to avoid versioning issues.
+ """
+ if name in group.attrs:
+ data = [n.decode("utf8") if hasattr(n, "decode") else n for n in group.attrs[name]]
+ else:
+ data = []
+ chunk_id = 0
+ while "%s%d" % (name, chunk_id) in group.attrs:
+ data.extend(
+ [n.decode("utf8") if hasattr(n, "decode") else n for n in group.attrs["%s%d" % (name, chunk_id)]]
+ )
+ chunk_id += 1
+ return data
+
+
+def expand_1d(data):
+ """Expands 1-dimensional `Tensor`s into 2-dimensional `Tensor`s.
+ Copied from Keras to here to avoid versioning issues."""
+
+ def _expand_single_1d_tensor(t):
+ if isinstance(t, tf.Tensor) and t.shape.rank == 1:
+ return tf.expand_dims(t, axis=-1)
+ return t
+
+ return tf.nest.map_structure(_expand_single_1d_tensor, data)
+
+
+def convert_batch_encoding(*args, **kwargs):
+ # Convert HF BatchEncoding/BatchFeature objects in the inputs to dicts that Keras understands
+ if args and isinstance(args[0], (BatchEncoding, BatchFeature)):
+ args = list(args)
+ args[0] = dict(args[0])
+ elif "x" in kwargs and isinstance(kwargs["x"], (BatchEncoding, BatchFeature)):
+ kwargs["x"] = dict(kwargs["x"])
+ return args, kwargs
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/time_series_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/time_series_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a5cf4f2f4d8f636e623c07ae3400ca5e17b5891
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/time_series_utils.py
@@ -0,0 +1,225 @@
+# Copyright 2023 The HuggingFace Inc. team.
+# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Time series distributional output classes and utilities.
+"""
+
+from typing import Callable, Optional
+
+import torch
+from torch import nn
+from torch.distributions import (
+ AffineTransform,
+ Distribution,
+ Independent,
+ NegativeBinomial,
+ Normal,
+ StudentT,
+ TransformedDistribution,
+)
+
+
+class AffineTransformed(TransformedDistribution):
+ def __init__(self, base_distribution: Distribution, loc=None, scale=None, event_dim=0):
+ self.scale = 1.0 if scale is None else scale
+ self.loc = 0.0 if loc is None else loc
+
+ super().__init__(base_distribution, [AffineTransform(loc=self.loc, scale=self.scale, event_dim=event_dim)])
+
+ @property
+ def mean(self):
+ """
+ Returns the mean of the distribution.
+ """
+ return self.base_dist.mean * self.scale + self.loc
+
+ @property
+ def variance(self):
+ """
+ Returns the variance of the distribution.
+ """
+ return self.base_dist.variance * self.scale**2
+
+ @property
+ def stddev(self):
+ """
+ Returns the standard deviation of the distribution.
+ """
+ return self.variance.sqrt()
+
+
+class ParameterProjection(nn.Module):
+ def __init__(
+ self, in_features: int, args_dim: dict[str, int], domain_map: Callable[..., tuple[torch.Tensor]], **kwargs
+ ) -> None:
+ super().__init__(**kwargs)
+ self.args_dim = args_dim
+ self.proj = nn.ModuleList([nn.Linear(in_features, dim) for dim in args_dim.values()])
+ self.domain_map = domain_map
+
+ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor]:
+ params_unbounded = [proj(x) for proj in self.proj]
+
+ return self.domain_map(*params_unbounded)
+
+
+class LambdaLayer(nn.Module):
+ def __init__(self, function):
+ super().__init__()
+ self.function = function
+
+ def forward(self, x, *args):
+ return self.function(x, *args)
+
+
+class DistributionOutput:
+ distribution_class: type
+ in_features: int
+ args_dim: dict[str, int]
+
+ def __init__(self, dim: int = 1) -> None:
+ self.dim = dim
+ self.args_dim = {k: dim * self.args_dim[k] for k in self.args_dim}
+
+ def _base_distribution(self, distr_args):
+ if self.dim == 1:
+ return self.distribution_class(*distr_args)
+ else:
+ return Independent(self.distribution_class(*distr_args), 1)
+
+ def distribution(
+ self,
+ distr_args,
+ loc: Optional[torch.Tensor] = None,
+ scale: Optional[torch.Tensor] = None,
+ ) -> Distribution:
+ distr = self._base_distribution(distr_args)
+ if loc is None and scale is None:
+ return distr
+ else:
+ return AffineTransformed(distr, loc=loc, scale=scale, event_dim=self.event_dim)
+
+ @property
+ def event_shape(self) -> tuple:
+ r"""
+ Shape of each individual event contemplated by the distributions that this object constructs.
+ """
+ return () if self.dim == 1 else (self.dim,)
+
+ @property
+ def event_dim(self) -> int:
+ r"""
+ Number of event dimensions, i.e., length of the `event_shape` tuple, of the distributions that this object
+ constructs.
+ """
+ return len(self.event_shape)
+
+ @property
+ def value_in_support(self) -> float:
+ r"""
+ A float that will have a valid numeric value when computing the log-loss of the corresponding distribution. By
+ default 0.0. This value will be used when padding data series.
+ """
+ return 0.0
+
+ def get_parameter_projection(self, in_features: int) -> nn.Module:
+ r"""
+ Return the parameter projection layer that maps the input to the appropriate parameters of the distribution.
+ """
+ return ParameterProjection(
+ in_features=in_features,
+ args_dim=self.args_dim,
+ domain_map=LambdaLayer(self.domain_map),
+ )
+
+ def domain_map(self, *args: torch.Tensor):
+ r"""
+ Converts arguments to the right shape and domain. The domain depends on the type of distribution, while the
+ correct shape is obtained by reshaping the trailing axis in such a way that the returned tensors define a
+ distribution of the right event_shape.
+ """
+ raise NotImplementedError()
+
+ @staticmethod
+ def squareplus(x: torch.Tensor) -> torch.Tensor:
+ r"""
+ Helper to map inputs to the positive orthant by applying the square-plus operation. Reference:
+ https://twitter.com/jon_barron/status/1387167648669048833
+ """
+ return (x + torch.sqrt(torch.square(x) + 4.0)) / 2.0
+
+
+class StudentTOutput(DistributionOutput):
+ """
+ Student-T distribution output class.
+ """
+
+ args_dim: dict[str, int] = {"df": 1, "loc": 1, "scale": 1}
+ distribution_class: type = StudentT
+
+ @classmethod
+ def domain_map(cls, df: torch.Tensor, loc: torch.Tensor, scale: torch.Tensor):
+ scale = cls.squareplus(scale).clamp_min(torch.finfo(scale.dtype).eps)
+ df = 2.0 + cls.squareplus(df)
+ return df.squeeze(-1), loc.squeeze(-1), scale.squeeze(-1)
+
+
+class NormalOutput(DistributionOutput):
+ """
+ Normal distribution output class.
+ """
+
+ args_dim: dict[str, int] = {"loc": 1, "scale": 1}
+ distribution_class: type = Normal
+
+ @classmethod
+ def domain_map(cls, loc: torch.Tensor, scale: torch.Tensor):
+ scale = cls.squareplus(scale).clamp_min(torch.finfo(scale.dtype).eps)
+ return loc.squeeze(-1), scale.squeeze(-1)
+
+
+class NegativeBinomialOutput(DistributionOutput):
+ """
+ Negative Binomial distribution output class.
+ """
+
+ args_dim: dict[str, int] = {"total_count": 1, "logits": 1}
+ distribution_class: type = NegativeBinomial
+
+ @classmethod
+ def domain_map(cls, total_count: torch.Tensor, logits: torch.Tensor):
+ total_count = cls.squareplus(total_count)
+ return total_count.squeeze(-1), logits.squeeze(-1)
+
+ def _base_distribution(self, distr_args) -> Distribution:
+ total_count, logits = distr_args
+ if self.dim == 1:
+ return self.distribution_class(total_count=total_count, logits=logits)
+ else:
+ return Independent(self.distribution_class(total_count=total_count, logits=logits), 1)
+
+ # Overwrites the parent class method. We cannot scale using the affine
+ # transformation since negative binomial should return integers. Instead
+ # we scale the parameters.
+ def distribution(
+ self, distr_args, loc: Optional[torch.Tensor] = None, scale: Optional[torch.Tensor] = None
+ ) -> Distribution:
+ total_count, logits = distr_args
+
+ if scale is not None:
+ # See scaling property of Gamma.
+ logits += scale.log()
+
+ return self._base_distribution((total_count, logits))
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/tokenization_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/tokenization_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b89e570931526a24555b4ff86ee1797f906f41fc
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/tokenization_utils.py
@@ -0,0 +1,1135 @@
+# Copyright 2020 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Tokenization classes for python tokenizers. For fast tokenizers (provided by HuggingFace's tokenizers library) see
+tokenization_utils_fast.py
+"""
+
+import bisect
+import itertools
+import re
+import unicodedata
+from collections import OrderedDict
+from typing import Any, Optional, Union, overload
+
+from .tokenization_utils_base import (
+ ENCODE_KWARGS_DOCSTRING,
+ ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING,
+ INIT_TOKENIZER_DOCSTRING,
+ AddedToken,
+ BatchEncoding,
+ EncodedInput,
+ EncodedInputPair,
+ PreTokenizedInput,
+ PreTokenizedInputPair,
+ PreTrainedTokenizerBase,
+ TextInput,
+ TextInputPair,
+ TruncationStrategy,
+)
+from .utils import PaddingStrategy, TensorType, add_end_docstrings, logging
+
+
+logger = logging.get_logger(__name__)
+
+# Slow tokenizers are saved in a vocabulary plus three separated files
+SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json"
+ADDED_TOKENS_FILE = "added_tokens.json"
+TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
+
+
+class Trie:
+ """
+ Trie in Python. Creates a Trie out of a list of words. The trie is used to split on `added_tokens` in one pass
+ Loose reference https://en.wikipedia.org/wiki/Trie
+ """
+
+ def __init__(self, *args):
+ self.data = {}
+ self._tokens = set()
+ self._termination_char = ""
+ self.update(*args)
+
+ def update(self, *args):
+ """
+ Updates the Trie with new tokens provided as arguments.
+
+ Args:
+ *args: Variable number of words to be added to the Trie.
+ """
+ for token in tuple(*args):
+ self.add(token)
+
+ def add(self, word: str):
+ """
+ Passes over every char (utf-8 char) on word and recursively adds it to the internal `data` trie representation.
+ The special key `""` in `self._termination_char` is used to represent termination.
+
+ This function is idempotent, adding twice the same word will leave the trie unchanged
+
+ Example:
+
+ ```python
+ >>> trie = Trie()
+ >>> trie.add("Hello 友達")
+ >>> trie.data
+ {"H": {"e": {"l": {"l": {"o": {" ": {"友": {"達": {"": 1}}}}}}}}}
+
+ >>> trie.add("Hello")
+ >>> trie.data
+ {"H": {"e": {"l": {"l": {"o": {"": 1, " ": {"友": {"達": {"": 1}}}}}}}}}
+ ```
+ """
+ if not word:
+ # Prevent empty string
+ return
+
+ self._tokens.add(word)
+ ref = self.data
+ for char in word:
+ ref[char] = ref.setdefault(char, {})
+ ref = ref[char]
+ ref[self._termination_char] = 1
+
+ def split(self, text: str) -> list[str]:
+ """
+ Will look for the words added to the trie within `text`. Output is the original string split along the
+ boundaries of the words found.
+
+ This trie will match the longest possible word first !
+
+ Example:
+
+ ```python
+ >>> trie = Trie()
+ >>> trie.split("[CLS] This is a extra_id_100")
+ ["[CLS] This is a extra_id_100"]
+
+ >>> trie.add("[CLS]")
+ >>> trie.add("extra_id_1")
+ >>> trie.add("extra_id_100")
+ >>> trie.split("[CLS] This is a extra_id_100")
+ ["[CLS]", " This is a ", "extra_id_100"]
+ ```
+ """
+ # indexes are counted left of the chars index.
+ # "hello", index 0, is left of h, index 1 is between h and e.
+ # index 5 is right of the "o".
+
+ # States are going to capture every possible start (indexes as above)
+ # as keys, and have as values, a pointer to the position in the trie
+ # where we're at. This is a partial match for now.
+ # This enables to keep track of multiple matches while we're iterating
+ # the string
+ # If the trie contains, "blowing", and "lower" and we encounter the
+ # string "blower", we need to split into ["b", "lower"].
+ # This is where we need to keep track of multiple possible starts.
+ states = OrderedDict()
+
+ # This will contain every indices where we need
+ # to cut.
+ # We force to cut at offset 0 and len(text) (added later)
+ offsets = [0]
+
+ # This is used by the lookahead which needs to skip over
+ # some text where the full match exceeded the place in the initial
+ # for loop
+ skip = 0
+ # Main loop, Giving this algorithm O(n) complexity
+ for current, current_char in enumerate(text):
+ if skip and current < skip:
+ # Prevents the lookahead for matching twice
+ # like extra_id_100 and id_100
+ continue
+
+ # This will track every state
+ # that stop matching, we need to stop tracking them.
+ # If we look at "lowball", we're going to match "l" (add it to states), "o", "w", then
+ # fail on "b", we need to remove 0 from the valid states.
+ to_remove = set()
+ # Whenever we found a match, we need to drop everything
+ # this is a greedy algorithm, it will match on the first found token
+ reset = False
+
+ # In this case, we already have partial matches (But unfinished)
+ for start, trie_pointer in states.items():
+ if "" in trie_pointer:
+ # This is a final match, we need to reset and
+ # store the results in `offsets`.
+
+ # Lookahead to match longest first
+ # Important in case of extra_id_1 vs extra_id_100
+ # Here we are also actively looking for other earlier partial
+ # matches
+ # "[CLS]", "L", we need to match CLS even if L is special
+ for lookstart, looktrie_pointer in states.items():
+ if lookstart > start:
+ # This partial match is later, we can stop looking
+ break
+ elif lookstart < start:
+ # This partial match is earlier, the trie pointer
+ # was already updated, so index is + 1
+ lookahead_index = current + 1
+ end = current + 1
+ else:
+ # Here lookstart == start and
+ # looktrie_pointer == trie_pointer
+ # It wasn't updated yet so indices are current ones
+ lookahead_index = current
+ end = current
+ next_char = text[lookahead_index] if lookahead_index < len(text) else None
+ if "" in looktrie_pointer:
+ start = lookstart
+ end = lookahead_index
+ skip = lookahead_index
+
+ while next_char in looktrie_pointer:
+ looktrie_pointer = looktrie_pointer[next_char]
+ lookahead_index += 1
+ if "" in looktrie_pointer:
+ start = lookstart
+ end = lookahead_index
+ skip = lookahead_index
+
+ if lookahead_index == len(text):
+ # End of string
+ break
+ next_char = text[lookahead_index]
+ # End lookahead
+
+ # Storing and resetting
+ offsets.append(start)
+ offsets.append(end)
+ reset = True
+ break
+ elif current_char in trie_pointer:
+ # The current character being looked at has a match within the trie
+ # update the pointer (it will be stored back into states later).
+ trie_pointer = trie_pointer[current_char]
+
+ # Storing back the new pointer into the states.
+ # Partial matches got longer by one.
+ states[start] = trie_pointer
+ else:
+ # The new character has not match in the trie, we need
+ # to stop keeping track of this partial match.
+ # We can't do it directly within the loop because of how
+ # python iteration works
+ to_remove.add(start)
+
+ # Either clearing the full start (we found a real match)
+ # Or clearing only the partial matches that didn't work.
+ if reset:
+ states = {}
+ else:
+ for start in to_remove:
+ del states[start]
+
+ # If this character is a starting character within the trie
+ # start keeping track of this partial match.
+ if current >= skip and current_char in self.data:
+ states[current] = self.data[current_char]
+
+ # We have a cut at the end with states.
+ for start, trie_pointer in states.items():
+ if "" in trie_pointer:
+ # This is a final match, we need to reset and
+ # store the results in `offsets`.
+ end = len(text)
+ offsets.append(start)
+ offsets.append(end)
+ # Longest cut is always the one with lower start so the first
+ # item so we need to break.
+ break
+
+ return self.cut_text(text, offsets)
+
+ def cut_text(self, text, offsets):
+ # We have all the offsets now, we just need to do the actual splitting.
+ # We need to eventually add the first part of the string and the eventual
+ # last part.
+ offsets.append(len(text))
+ tokens = []
+ start = 0
+ for end in offsets:
+ if start > end:
+ logger.error(
+ "There was a bug in Trie algorithm in tokenization. Attempting to recover. Please report it"
+ " anyway."
+ )
+ continue
+ elif start == end:
+ # This might happen if there's a match at index 0
+ # we're also preventing zero-width cuts in case of two
+ # consecutive matches
+ continue
+ tokens.append(text[start:end])
+ start = end
+
+ return tokens
+
+
+class ExtensionsTrie(Trie):
+ def __init__(self, *args):
+ super().__init__(*args)
+
+ def extensions(self, prefix: str):
+ """
+ Generates all extensions of a given prefix token in the Trie.
+
+ Example:
+
+ ```python
+ >>> trie = Trie()
+ >>> trie.add("apple")
+ >>> trie.add("app")
+ >>> trie.add("application")
+ >>> trie.extensions("app")
+ ['app', 'apple', 'application']
+ ```
+ """
+ prefix_node = self._get_node(prefix)
+ ret = self._collect_tokens(prefix_node)
+ return [prefix + token for token in ret]
+
+ def _get_node(self, token: str) -> dict:
+ """
+ Retrieves the node corresponding to the given token in the Trie.
+
+ Args:
+ token (str): The token for which the corresponding node needs to be retrieved.
+
+ Returns:
+ dict: The node in the Trie corresponding to the given token.
+ """
+ node = self.data
+ for char in token:
+ if char not in node:
+ break
+
+ node = node[char]
+ return node
+
+ def _collect_tokens(self, node: dict) -> list:
+ """
+ Generates all tokens in the Trie starting from a given node.
+
+ Args:
+ node (dict): The node in the Trie from which tokens need to be generated.
+
+ Returns:
+ list: List of tokens generated from the given node.
+ """
+ tokens = [self._termination_char] if self._termination_char in node else []
+ for token, subtrie_head in node.items():
+ if token != self._termination_char:
+ subtokens = self._collect_tokens(subtrie_head)
+ tokens.extend([token + subtoken for subtoken in subtokens])
+ return tokens
+
+
+def _is_whitespace(char):
+ """Checks whether `char` is a whitespace character."""
+ # \t, \n, and \r are technically control characters but we treat them
+ # as whitespace since they are generally considered as such.
+ if char == " " or char == "\t" or char == "\n" or char == "\r":
+ return True
+ cat = unicodedata.category(char)
+ if cat == "Zs":
+ return True
+ return False
+
+
+def _is_control(char):
+ """Checks whether `char` is a control character."""
+ # These are technically control characters but we count them as whitespace
+ # characters.
+ if char == "\t" or char == "\n" or char == "\r":
+ return False
+ cat = unicodedata.category(char)
+ if cat.startswith("C"):
+ return True
+ return False
+
+
+def _is_punctuation(char):
+ """Checks whether `char` is a punctuation character."""
+ cp = ord(char)
+ # We treat all non-letter/number ASCII as punctuation.
+ # Characters such as "^", "$", and "`" are not in the Unicode
+ # Punctuation class but we treat them as punctuation anyways, for
+ # consistency.
+ if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126):
+ return True
+ cat = unicodedata.category(char)
+ if cat.startswith("P"):
+ return True
+ return False
+
+
+def _is_end_of_word(text):
+ """Checks whether the last character in text is one of a punctuation, control or whitespace character."""
+ last_char = text[-1]
+ return bool(_is_control(last_char) | _is_punctuation(last_char) | _is_whitespace(last_char))
+
+
+def _is_start_of_word(text):
+ """Checks whether the first character in text is one of a punctuation, control or whitespace character."""
+ first_char = text[0]
+ return bool(_is_control(first_char) | _is_punctuation(first_char) | _is_whitespace(first_char))
+
+
+def _insert_one_token_to_ordered_list(token_list: list[str], new_token: str):
+ """
+ Inserts one token to an ordered list if it does not already exist. Note: token_list must be sorted.
+ """
+ insertion_idx = bisect.bisect_left(token_list, new_token)
+ # Checks if new_token is already in the ordered token_list
+ if insertion_idx < len(token_list) and token_list[insertion_idx] == new_token:
+ # new_token is in token_list, don't add
+ return
+ else:
+ token_list.insert(insertion_idx, new_token)
+
+
+@add_end_docstrings(INIT_TOKENIZER_DOCSTRING)
+class PreTrainedTokenizer(PreTrainedTokenizerBase):
+ """
+ Base class for all slow tokenizers.
+
+ Inherits from [`~tokenization_utils_base.PreTrainedTokenizerBase`].
+
+ Handle all the shared methods for tokenization and special tokens as well as methods downloading/caching/loading
+ pretrained tokenizers as well as adding tokens to the vocabulary.
+
+ This class also contain the added tokens in a unified way on top of all tokenizers so we don't have to handle the
+ specific vocabulary augmentation methods of the various underlying dictionary structures (BPE, sentencepiece...).
+ """
+
+ def __init__(self, **kwargs):
+ # 1. Init the parent class
+
+ self.tokens_trie = Trie()
+
+ # 2. init `_added_tokens_decoder` if child class did not
+ if not hasattr(self, "_added_tokens_decoder"):
+ self._added_tokens_decoder: dict[int, AddedToken] = {}
+
+ # 3. if a `added_tokens_decoder` is passed, we are loading from a saved tokenizer, we overwrite
+ self._added_tokens_decoder.update(kwargs.pop("added_tokens_decoder", {}))
+ self._added_tokens_encoder: dict[str, int] = {k.content: v for v, k in self._added_tokens_decoder.items()}
+
+ # 4 init the parent class
+ super().__init__(**kwargs)
+
+ # 4. If some of the special tokens are not part of the vocab, we add them, at the end.
+ # the order of addition is the same as self.SPECIAL_TOKENS_ATTRIBUTES following `tokenizers`
+ self._add_tokens(
+ [token for token in self.all_special_tokens_extended if token not in self._added_tokens_encoder],
+ special_tokens=True,
+ )
+
+ self._decode_use_source_tokenizer = False
+
+ @property
+ def is_fast(self) -> bool:
+ return False
+
+ @property
+ def vocab_size(self) -> int:
+ """
+ `int`: Size of the base vocabulary (without the added tokens).
+ """
+ raise NotImplementedError
+
+ @property
+ def added_tokens_encoder(self) -> dict[str, int]:
+ """
+ Returns the sorted mapping from string to index. The added tokens encoder is cached for performance
+ optimisation in `self._added_tokens_encoder` for the slow tokenizers.
+ """
+ return {k.content: v for v, k in sorted(self._added_tokens_decoder.items(), key=lambda item: item[0])}
+
+ @property
+ def added_tokens_decoder(self) -> dict[int, AddedToken]:
+ """
+ Returns the added tokens in the vocabulary as a dictionary of index to AddedToken.
+
+ Returns:
+ `dict[str, int]`: The added tokens.
+ """
+ return dict(sorted(self._added_tokens_decoder.items(), key=lambda item: item[0]))
+
+ @added_tokens_decoder.setter
+ def added_tokens_decoder(self, value: dict[int, Union[AddedToken, str]]) -> dict[int, AddedToken]:
+ # Always raise an error if string because users should define the behavior
+ for index, token in value.items():
+ if not isinstance(token, (str, AddedToken)) or not isinstance(index, int):
+ raise TypeError(
+ f"The provided `added_tokens_decoder` has an element of type {index.__class__, token.__class__}, should be a dict of {int, Union[AddedToken, str]}"
+ )
+
+ self._added_tokens_decoder[index] = AddedToken(token) if isinstance(token, str) else token
+ self._added_tokens_encoder[str(token)] = index
+ self._update_total_vocab_size()
+
+ def get_added_vocab(self) -> dict[str, int]:
+ """
+ Returns the added tokens in the vocabulary as a dictionary of token to index. Results might be different from
+ the fast call because for now we always add the tokens even if they are already in the vocabulary. This is
+ something we should change.
+
+ Returns:
+ `dict[str, int]`: The added tokens.
+ """
+ return self._added_tokens_encoder
+
+ def __len__(self):
+ """
+ Size of the full vocabulary with the added tokens.
+ """
+ return self.total_vocab_size
+
+ def _update_total_vocab_size(self):
+ """
+ Update the size of the full vocabulary with the added tokens. Counts the `keys` and not the `values` because
+ otherwise if there is a hole in the vocab, we will add tokenizers at a wrong index. This operation is slow and
+ is only updated when adding tokens.
+ """
+ self.total_vocab_size = len(self.get_vocab())
+
+ def _add_tokens(self, new_tokens: Union[list[str], list[AddedToken]], special_tokens: bool = False) -> int:
+ """
+ Add a list of new tokens to the tokenizer class. If the new tokens are not in the vocabulary, they are added to
+ it with indices starting from length of the current vocabulary. Special tokens are sometimes already in the
+ vocab which is why they have to be handled specifically.
+
+ Args:
+ new_tokens (`list[str]`or `list[tokenizers.AddedToken]`):
+ Token(s) to add in vocabulary. A token is counted as added if it's not already in the vocabulary
+ (tested by checking if the tokenizer assign the index of the `unk_token` to them). If a token is part
+ of the vocabulary then we simply mark this token as an `AddedToken` which allows to control the
+ stripping and normalization of this token. This is NOT possible in `tokenizers`.
+ special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not the tokens should be added as special tokens.
+
+ Returns:
+ `int`: The number of tokens actually added to the vocabulary.
+
+ Examples:
+
+ ```python
+ # Let's see how to increase the vocabulary of Bert model and tokenizer
+ tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")
+ model = BertModel.from_pretrained("google-bert/bert-base-uncased")
+
+ num_added_toks = tokenizer.add_tokens(["new_tok1", "my_new-tok2"])
+ print("We have added", num_added_toks, "tokens")
+ # Note: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
+ model.resize_token_embeddings(len(tokenizer))
+ ```"""
+ added_tokens = 0
+ if new_tokens is None:
+ return added_tokens
+ # TODO this is fairly slow to improve!
+ current_vocab = self.get_vocab().copy()
+ new_idx = len(current_vocab) # only call this once, len gives the last index + 1
+ for token in new_tokens:
+ if not isinstance(token, (str, AddedToken)):
+ raise TypeError(f"Token {token} is not a string but a {type(token)}.")
+ if str(token) == "":
+ continue
+ if isinstance(token, str):
+ if token in self._added_tokens_encoder:
+ continue
+ else:
+ # very important for fast and slow equivalence!
+ is_special = token in self.all_special_tokens or special_tokens
+ token = AddedToken(
+ token, rstrip=False, lstrip=False, normalized=not is_special, special=is_special
+ )
+ elif special_tokens:
+ # doing token.special=True changes the normalization! will fix in rust
+ # this is important and the only reason why the AddedTokens in each class are normalized by default
+ token.__setstate__({"special": True, "normalized": token.normalized})
+ if token in self._added_tokens_decoder:
+ continue
+ if not token.special and token.normalized and getattr(self, "do_lower_case", False):
+ # Normalize if requested
+ token.content = token.content.lower()
+ if token.content not in current_vocab:
+ token_index = new_idx + added_tokens
+ current_vocab[token.content] = token_index
+ added_tokens += 1
+ else:
+ token_index = current_vocab[token.content]
+
+ if token.special and str(token) not in self.all_special_tokens:
+ self._special_tokens_map["additional_special_tokens"].append(token)
+ # the setter automatically updates the reverse map
+ self._added_tokens_decoder[token_index] = token
+ self._added_tokens_encoder[token.content] = token_index
+ if self.verbose:
+ logger.info(f"Adding {token} to the vocabulary")
+
+ self._update_trie()
+ self._update_total_vocab_size()
+ return added_tokens
+
+ def _update_trie(self, unique_no_split_tokens: Optional[list[str]] = None):
+ for token in self._added_tokens_decoder.values():
+ if token.content not in self.tokens_trie._tokens:
+ self.tokens_trie.add(token.content)
+ for token in unique_no_split_tokens or []:
+ if token not in self.tokens_trie._tokens:
+ self.tokens_trie.add(token)
+
+ def num_special_tokens_to_add(self, pair: bool = False) -> int:
+ """
+ Returns the number of added tokens when encoding a sequence with special tokens.
+
+
+
+ This encodes a dummy input and checks the number of added tokens, and is therefore not efficient. Do not put
+ this inside your training loop.
+
+
+
+ Args:
+ pair (`bool`, *optional*, defaults to `False`):
+ Whether the number of added tokens should be computed in the case of a sequence pair or a single
+ sequence.
+
+ Returns:
+ `int`: Number of special tokens added to sequences.
+ """
+ token_ids_0 = []
+ token_ids_1 = []
+ return len(self.build_inputs_with_special_tokens(token_ids_0, token_ids_1 if pair else None))
+
+ def tokenize(self, text: TextInput, **kwargs) -> list[str]:
+ """
+ Converts a string into a sequence of tokens, using the tokenizer.
+
+ Split in words for word-based vocabulary or sub-words for sub-word-based vocabularies
+ (BPE/SentencePieces/WordPieces). Takes care of added tokens.
+
+ Args:
+ text (`str`):
+ The sequence to be encoded.
+ **kwargs (additional keyword arguments):
+ Passed along to the model-specific `prepare_for_tokenization` preprocessing method.
+
+ Returns:
+ `list[str]`: The list of tokens.
+ """
+ split_special_tokens = kwargs.pop("split_special_tokens", self.split_special_tokens)
+
+ text, kwargs = self.prepare_for_tokenization(text, **kwargs)
+
+ if kwargs:
+ logger.warning(f"Keyword arguments {kwargs} not recognized.")
+
+ if hasattr(self, "do_lower_case") and self.do_lower_case:
+ # convert non-special tokens to lowercase. Might be super slow as well?
+ escaped_special_toks = [re.escape(s_tok) for s_tok in (self.all_special_tokens)]
+ escaped_special_toks += [
+ re.escape(s_tok.content)
+ for s_tok in (self._added_tokens_decoder.values())
+ if not s_tok.special and s_tok.normalized
+ ]
+ pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
+ text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text)
+
+ if split_special_tokens:
+ no_split_token = []
+ tokens = [text]
+ else:
+ no_split_token = self._added_tokens_encoder.keys() # don't split on any of the added tokens
+ # "This is something else"
+ tokens = self.tokens_trie.split(text)
+
+ # ["This is something", "", " else"]
+ for i, token in enumerate(tokens):
+ if token in no_split_token:
+ tok_extended = self._added_tokens_decoder.get(self._added_tokens_encoder[token], None)
+ left = tokens[i - 1] if i > 0 else None
+ right = tokens[i + 1] if i < len(tokens) - 1 else None
+ if isinstance(tok_extended, AddedToken):
+ if tok_extended.rstrip and right:
+ # A bit counter-intuitive but we strip the left of the string
+ # since tok_extended.rstrip means the special token is eating all white spaces on its right
+ tokens[i + 1] = right.lstrip()
+ # Strip white spaces on the left
+ if tok_extended.lstrip and left:
+ tokens[i - 1] = left.rstrip() # Opposite here
+ if tok_extended.single_word and left and left[-1] != " ":
+ tokens[i - 1] += token
+ tokens[i] = ""
+ elif tok_extended.single_word and right and right[0] != " ":
+ tokens[i + 1] = token + tokens[i + 1]
+ tokens[i] = ""
+ else:
+ raise ValueError(
+ f"{tok_extended} cannot be tokenized because it was not properly added"
+ f" to the tokenizer. This means that it is not an `AddedToken` but a {type(tok_extended)}"
+ )
+ # ["This is something", "", "else"]
+ tokenized_text = []
+ for token in tokens:
+ # Need to skip eventual empty (fully stripped) tokens
+ if not token:
+ continue
+ if token in no_split_token:
+ tokenized_text.append(token)
+ else:
+ tokenized_text.extend(self._tokenize(token))
+ # ["This", " is", " something", "", "else"]
+ return tokenized_text
+
+ def _tokenize(self, text, **kwargs):
+ """
+ Converts a string into a sequence of tokens (string), using the tokenizer. Split in words for word-based
+ vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces).
+
+ Do NOT take care of added tokens.
+ """
+ raise NotImplementedError
+
+ def convert_tokens_to_ids(self, tokens: Union[str, list[str]]) -> Union[int, list[int]]:
+ """
+ Converts a token string (or a sequence of tokens) in a single integer id (or a sequence of ids), using the
+ vocabulary.
+
+ Args:
+ tokens (`str` or `list[str]`): One or several token(s) to convert to token id(s).
+
+ Returns:
+ `int` or `list[int]`: The token id or list of token ids.
+ """
+ if tokens is None:
+ return None
+
+ if isinstance(tokens, str):
+ return self._convert_token_to_id_with_added_voc(tokens)
+
+ ids = []
+ for token in tokens:
+ ids.append(self._convert_token_to_id_with_added_voc(token))
+ return ids
+
+ def _convert_token_to_id_with_added_voc(self, token):
+ if token is None:
+ return None
+
+ if token in self._added_tokens_encoder:
+ return self._added_tokens_encoder[token]
+ return self._convert_token_to_id(token)
+
+ def _convert_token_to_id(self, token):
+ raise NotImplementedError
+
+ def _encode_plus(
+ self,
+ text: Union[TextInput, PreTokenizedInput, EncodedInput],
+ text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None,
+ add_special_tokens: bool = True,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ is_split_into_words: bool = False,
+ pad_to_multiple_of: Optional[int] = None,
+ padding_side: Optional[str] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ **kwargs,
+ ) -> BatchEncoding:
+ def get_input_ids(text):
+ if isinstance(text, str):
+ tokens = self.tokenize(text, **kwargs)
+ return self.convert_tokens_to_ids(tokens)
+ elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):
+ if is_split_into_words:
+ tokens = list(
+ itertools.chain(*(self.tokenize(t, is_split_into_words=True, **kwargs) for t in text))
+ )
+ return self.convert_tokens_to_ids(tokens)
+ else:
+ return self.convert_tokens_to_ids(text)
+ elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
+ return text
+ else:
+ if is_split_into_words:
+ raise ValueError(
+ f"Input {text} is not valid. Should be a string or a list/tuple of strings when"
+ " `is_split_into_words=True`."
+ )
+ else:
+ raise ValueError(
+ f"Input {text} is not valid. Should be a string, a list/tuple of strings or a list/tuple of"
+ " integers."
+ )
+
+ if return_offsets_mapping:
+ raise NotImplementedError(
+ "return_offset_mapping is not available when using Python tokenizers. "
+ "To use this feature, change your tokenizer to one deriving from "
+ "transformers.PreTrainedTokenizerFast. "
+ "More information on available tokenizers at "
+ "https://github.com/huggingface/transformers/pull/2674"
+ )
+
+ first_ids = get_input_ids(text)
+ second_ids = get_input_ids(text_pair) if text_pair is not None else None
+
+ return self.prepare_for_model(
+ first_ids,
+ pair_ids=second_ids,
+ add_special_tokens=add_special_tokens,
+ padding=padding_strategy.value,
+ truncation=truncation_strategy.value,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ padding_side=padding_side,
+ return_tensors=return_tensors,
+ prepend_batch_axis=True,
+ return_attention_mask=return_attention_mask,
+ return_token_type_ids=return_token_type_ids,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_length=return_length,
+ verbose=verbose,
+ )
+
+ def _batch_encode_plus(
+ self,
+ batch_text_or_text_pairs: Union[
+ list[TextInput],
+ list[TextInputPair],
+ list[PreTokenizedInput],
+ list[PreTokenizedInputPair],
+ list[EncodedInput],
+ list[EncodedInputPair],
+ ],
+ add_special_tokens: bool = True,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ is_split_into_words: bool = False,
+ pad_to_multiple_of: Optional[int] = None,
+ padding_side: Optional[str] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ split_special_tokens: bool = False,
+ **kwargs,
+ ) -> BatchEncoding:
+ def get_input_ids(text):
+ if isinstance(text, str):
+ tokens = self.tokenize(text, **kwargs)
+ return self.convert_tokens_to_ids(tokens)
+ elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):
+ if is_split_into_words:
+ tokens = list(
+ itertools.chain(*(self.tokenize(t, is_split_into_words=True, **kwargs) for t in text))
+ )
+ return self.convert_tokens_to_ids(tokens)
+ else:
+ return self.convert_tokens_to_ids(text)
+ elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
+ return text
+ else:
+ raise ValueError(
+ "Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers."
+ )
+
+ if return_offsets_mapping:
+ raise NotImplementedError(
+ "return_offset_mapping is not available when using Python tokenizers. "
+ "To use this feature, change your tokenizer to one deriving from "
+ "transformers.PreTrainedTokenizerFast."
+ )
+
+ input_ids = []
+ for ids_or_pair_ids in batch_text_or_text_pairs:
+ if (
+ not isinstance(ids_or_pair_ids, (list, tuple))
+ or is_split_into_words
+ and not isinstance(ids_or_pair_ids[0], (list, tuple))
+ ):
+ ids, pair_ids = ids_or_pair_ids, None
+ else:
+ ids, pair_ids = ids_or_pair_ids
+
+ first_ids = get_input_ids(ids)
+ second_ids = get_input_ids(pair_ids) if pair_ids is not None else None
+ input_ids.append((first_ids, second_ids))
+
+ batch_outputs = self._batch_prepare_for_model(
+ input_ids,
+ add_special_tokens=add_special_tokens,
+ padding_strategy=padding_strategy,
+ truncation_strategy=truncation_strategy,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ padding_side=padding_side,
+ return_attention_mask=return_attention_mask,
+ return_token_type_ids=return_token_type_ids,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_length=return_length,
+ return_tensors=return_tensors,
+ verbose=verbose,
+ split_special_tokens=split_special_tokens,
+ )
+
+ return BatchEncoding(batch_outputs)
+
+ @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
+ def _batch_prepare_for_model(
+ self,
+ batch_ids_pairs: list[Union[PreTokenizedInputPair, tuple[list[int], None]]],
+ add_special_tokens: bool = True,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ padding_side: Optional[str] = None,
+ return_tensors: Optional[str] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ split_special_tokens: bool = False,
+ ) -> BatchEncoding:
+ """
+ Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. It
+ adds special tokens, truncates sequences if overflowing while taking into account the special tokens and
+ manages a moving window (with user defined stride) for overflowing tokens
+
+ Args:
+ batch_ids_pairs: list of tokenized input ids or input ids pairs
+ """
+
+ batch_outputs = {}
+ for first_ids, second_ids in batch_ids_pairs:
+ outputs = self.prepare_for_model(
+ first_ids,
+ second_ids,
+ add_special_tokens=add_special_tokens,
+ padding=PaddingStrategy.DO_NOT_PAD.value, # we pad in batch afterward
+ truncation=truncation_strategy.value,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=None, # we pad in batch afterward
+ padding_side=None, # we pad in batch afterward
+ return_attention_mask=False, # we pad in batch afterward
+ return_token_type_ids=return_token_type_ids,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_length=return_length,
+ return_tensors=None, # We convert the whole batch to tensors at the end
+ prepend_batch_axis=False,
+ verbose=verbose,
+ split_special_tokens=split_special_tokens,
+ )
+
+ for key, value in outputs.items():
+ if key not in batch_outputs:
+ batch_outputs[key] = []
+ batch_outputs[key].append(value)
+
+ batch_outputs = self.pad(
+ batch_outputs,
+ padding=padding_strategy.value,
+ max_length=max_length,
+ pad_to_multiple_of=pad_to_multiple_of,
+ padding_side=padding_side,
+ return_attention_mask=return_attention_mask,
+ )
+
+ batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors)
+
+ return batch_outputs
+
+ def prepare_for_tokenization(
+ self, text: str, is_split_into_words: bool = False, **kwargs
+ ) -> tuple[str, dict[str, Any]]:
+ """
+ Performs any necessary transformations before tokenization.
+
+ This method should pop the arguments from kwargs and return the remaining `kwargs` as well. We test the
+ `kwargs` at the end of the encoding process to be sure all the arguments have been used.
+
+ Args:
+ text (`str`):
+ The text to prepare.
+ is_split_into_words (`bool`, *optional*, defaults to `False`):
+ Whether or not the input is already pre-tokenized (e.g., split into words). If set to `True`, the
+ tokenizer assumes the input is already split into words (for instance, by splitting it on whitespace)
+ which it will tokenize. This is useful for NER or token classification.
+ kwargs (`dict[str, Any]`, *optional*):
+ Keyword arguments to use for the tokenization.
+
+ Returns:
+ `tuple[str, dict[str, Any]]`: The prepared text and the unused kwargs.
+ """
+ return (text, kwargs)
+
+ def get_special_tokens_mask(
+ self, token_ids_0: list, token_ids_1: Optional[list] = None, already_has_special_tokens: bool = False
+ ) -> list[int]:
+ """
+ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
+ special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods.
+
+ Args:
+ token_ids_0 (`list[int]`):
+ List of ids of the first sequence.
+ token_ids_1 (`list[int]`, *optional*):
+ List of ids of the second sequence.
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not the token list is already formatted with special tokens for the model.
+
+ Returns:
+ A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+ """
+ if already_has_special_tokens:
+ if token_ids_1 is not None:
+ raise ValueError(
+ "You should not supply a second sequence if the provided sequence of "
+ "ids is already formatted with special tokens for the model."
+ )
+
+ return super().get_special_tokens_mask(
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+ )
+ return [0] * ((len(token_ids_1) if token_ids_1 else 0) + len(token_ids_0))
+
+ @overload
+ def convert_ids_to_tokens(self, ids: int, skip_special_tokens: bool = False) -> str: ...
+
+ @overload
+ def convert_ids_to_tokens(self, ids: list[int], skip_special_tokens: bool = False) -> list[str]: ...
+
+ def convert_ids_to_tokens(
+ self, ids: Union[int, list[int]], skip_special_tokens: bool = False
+ ) -> Union[str, list[str]]:
+ """
+ Converts a single index or a sequence of indices in a token or a sequence of tokens, using the vocabulary and
+ added tokens.
+
+ Args:
+ ids (`int` or `list[int]`):
+ The token id (or token ids) to convert to tokens.
+ skip_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not to remove special tokens in the decoding.
+
+ Returns:
+ `str` or `list[str]`: The decoded token(s).
+ """
+ if isinstance(ids, int):
+ if ids in self._added_tokens_decoder:
+ return self._added_tokens_decoder[ids].content
+ else:
+ return self._convert_id_to_token(ids)
+ tokens = []
+ for index in ids:
+ index = int(index)
+ if skip_special_tokens and index in self.all_special_ids:
+ continue
+ if index in self._added_tokens_decoder:
+ tokens.append(self._added_tokens_decoder[index].content)
+ else:
+ tokens.append(self._convert_id_to_token(index))
+ return tokens
+
+ def _convert_id_to_token(self, index: int) -> str:
+ raise NotImplementedError
+
+ def convert_tokens_to_string(self, tokens: list[str]) -> str:
+ return " ".join(tokens)
+
+ def _decode(
+ self,
+ token_ids: Union[int, list[int]],
+ skip_special_tokens: bool = False,
+ clean_up_tokenization_spaces: Optional[bool] = None,
+ spaces_between_special_tokens: bool = True,
+ **kwargs,
+ ) -> str:
+ self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)
+
+ filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
+ # If given is a single id, prevents splitting the string in upcoming loop
+ if isinstance(filtered_tokens, str):
+ filtered_tokens = [filtered_tokens]
+
+ legacy_added_tokens = set(self._added_tokens_encoder.keys()) - set(self.all_special_tokens) | {
+ token for token in self.additional_special_tokens if self.convert_tokens_to_ids(token) >= self.vocab_size
+ }
+ # To avoid mixing byte-level and unicode for byte-level BPT
+ # we need to build string separately for added tokens and byte-level tokens
+ # cf. https://github.com/huggingface/transformers/issues/1133
+ sub_texts = []
+ current_sub_text = []
+ # TODO @ArthurZ in version 5, special tokens should be handled in convert_tokens_to_string, while _convert_tokens_to_string
+ for token in filtered_tokens:
+ if skip_special_tokens and token in self.all_special_tokens:
+ continue
+ if token in legacy_added_tokens:
+ if current_sub_text:
+ string = self.convert_tokens_to_string(current_sub_text)
+ if len(string) > 0:
+ sub_texts.append(string)
+ current_sub_text = []
+ sub_texts.append(token)
+ else:
+ current_sub_text.append(token)
+ if current_sub_text:
+ sub_texts.append(self.convert_tokens_to_string(current_sub_text))
+
+ if spaces_between_special_tokens:
+ text = " ".join(sub_texts)
+ else:
+ text = "".join(sub_texts)
+
+ clean_up_tokenization_spaces = (
+ clean_up_tokenization_spaces
+ if clean_up_tokenization_spaces is not None
+ else self.clean_up_tokenization_spaces
+ )
+ if clean_up_tokenization_spaces:
+ clean_text = self.clean_up_tokenization(text)
+ return clean_text
+ else:
+ return text
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/tokenization_utils_base.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/tokenization_utils_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..2dbbe296aa22c87e7dc4c13d1e1ccd4de390c6f0
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/tokenization_utils_base.py
@@ -0,0 +1,4366 @@
+# coding=utf-8
+# Copyright 2020 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Base classes common to both the slow and the fast tokenization classes: PreTrainedTokenizerBase (host all the user
+fronting encoding methods) Special token mixing (host the special tokens logic) and BatchEncoding (wrap the dictionary
+of output with special method for the Fast tokenizers)
+"""
+
+import copy
+import json
+import os
+import re
+import warnings
+from collections import UserDict
+from collections.abc import Mapping, Sequence, Sized
+from contextlib import contextmanager
+from dataclasses import dataclass
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Callable, NamedTuple, Optional, Union
+
+import numpy as np
+from huggingface_hub import list_repo_files
+from packaging import version
+
+from . import __version__
+from .dynamic_module_utils import custom_object_save
+from .utils import (
+ CHAT_TEMPLATE_DIR,
+ CHAT_TEMPLATE_FILE,
+ ExplicitEnum,
+ PaddingStrategy,
+ PushToHubMixin,
+ TensorType,
+ add_end_docstrings,
+ cached_file,
+ copy_func,
+ download_url,
+ extract_commit_hash,
+ is_flax_available,
+ is_jax_tensor,
+ is_mlx_available,
+ is_numpy_array,
+ is_offline_mode,
+ is_protobuf_available,
+ is_remote_url,
+ is_tf_available,
+ is_tf_tensor,
+ is_tokenizers_available,
+ is_torch_available,
+ is_torch_device,
+ is_torch_tensor,
+ list_repo_templates,
+ logging,
+ requires_backends,
+ to_py_obj,
+)
+from .utils.chat_template_utils import render_jinja_template
+from .utils.import_utils import PROTOBUF_IMPORT_ERROR
+
+
+if TYPE_CHECKING:
+ if is_torch_available():
+ import torch
+ if is_tf_available():
+ import tensorflow as tf
+ if is_flax_available():
+ import jax.numpy as jnp # noqa: F401
+
+
+def import_protobuf_decode_error(error_message=""):
+ if is_protobuf_available():
+ from google.protobuf.message import DecodeError
+
+ return DecodeError
+ else:
+ raise ImportError(PROTOBUF_IMPORT_ERROR.format(error_message))
+
+
+def flatten(arr: list):
+ res = []
+ if len(arr) > 0:
+ for sub_arr in arr:
+ if isinstance(arr[0], (list, tuple)):
+ res.extend(flatten(sub_arr))
+ else:
+ res.append(sub_arr)
+ return res
+
+
+if is_tokenizers_available() or TYPE_CHECKING:
+ from tokenizers import Encoding as EncodingFast
+
+if is_tokenizers_available():
+ from tokenizers import AddedToken
+else:
+
+ @dataclass(frozen=False, eq=True)
+ class AddedToken:
+ """
+ AddedToken represents a token to be added to a Tokenizer An AddedToken can have special options defining the
+ way it should behave.
+
+ The `normalized` will default to `not special` if it is not specified, similarly to the definition in
+ `tokenizers`.
+ """
+
+ def __init__(
+ self, content: str, single_word=False, lstrip=False, rstrip=False, special=False, normalized=None
+ ):
+ self.content = content
+ self.single_word = single_word
+ self.lstrip = lstrip
+ self.rstrip = rstrip
+ self.special = special
+ self.normalized = normalized if normalized is not None else not special
+
+ def __getstate__(self):
+ return self.__dict__
+
+ def __str__(self):
+ return self.content
+
+
+logger = logging.get_logger(__name__)
+
+VERY_LARGE_INTEGER = int(1e30) # This is used to set the max input length for a model with infinite size input
+LARGE_INTEGER = int(1e20) # This is used when we need something big but slightly smaller than VERY_LARGE_INTEGER
+
+# Define type aliases and NamedTuples
+TextInput = str
+PreTokenizedInput = list[str]
+EncodedInput = list[int]
+TextInputPair = tuple[str, str]
+PreTokenizedInputPair = tuple[list[str], list[str]]
+EncodedInputPair = tuple[list[int], list[int]]
+
+# Define type aliases for text-related non-text modalities
+AudioInput = Union[np.ndarray, "torch.Tensor", list[np.ndarray], list["torch.Tensor"]]
+
+# Slow tokenizers used to be saved in three separated files
+SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json"
+ADDED_TOKENS_FILE = "added_tokens.json"
+TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
+
+# Fast tokenizers (provided by HuggingFace tokenizer's library) can be saved in a single file
+FULL_TOKENIZER_FILE = "tokenizer.json"
+_re_tokenizer_file = re.compile(r"tokenizer\.(.*)\.json")
+
+
+class TruncationStrategy(ExplicitEnum):
+ """
+ Possible values for the `truncation` argument in [`PreTrainedTokenizerBase.__call__`]. Useful for tab-completion in
+ an IDE.
+ """
+
+ ONLY_FIRST = "only_first"
+ ONLY_SECOND = "only_second"
+ LONGEST_FIRST = "longest_first"
+ DO_NOT_TRUNCATE = "do_not_truncate"
+
+
+class CharSpan(NamedTuple):
+ """
+ Character span in the original string.
+
+ Args:
+ start (`int`): Index of the first character in the original string.
+ end (`int`): Index of the character following the last character in the original string.
+ """
+
+ start: int
+ end: int
+
+
+class TokenSpan(NamedTuple):
+ """
+ Token span in an encoded string (list of tokens).
+
+ Args:
+ start (`int`): Index of the first token in the span.
+ end (`int`): Index of the token following the last token in the span.
+ """
+
+ start: int
+ end: int
+
+
+class BatchEncoding(UserDict):
+ """
+ Holds the output of the [`~tokenization_utils_base.PreTrainedTokenizerBase.__call__`],
+ [`~tokenization_utils_base.PreTrainedTokenizerBase.encode_plus`] and
+ [`~tokenization_utils_base.PreTrainedTokenizerBase.batch_encode_plus`] methods (tokens, attention_masks, etc).
+
+ This class is derived from a python dictionary and can be used as a dictionary. In addition, this class exposes
+ utility methods to map from word/character space to token space.
+
+ Args:
+ data (`dict`, *optional*):
+ Dictionary of lists/arrays/tensors returned by the `__call__`/`encode_plus`/`batch_encode_plus` methods
+ ('input_ids', 'attention_mask', etc.).
+ encoding (`tokenizers.Encoding` or `Sequence[tokenizers.Encoding]`, *optional*):
+ If the tokenizer is a fast tokenizer which outputs additional information like mapping from word/character
+ space to token space the `tokenizers.Encoding` instance or list of instance (for batches) hold this
+ information.
+ tensor_type (`Union[None, str, TensorType]`, *optional*):
+ You can give a tensor_type here to convert the lists of integers in PyTorch/TensorFlow/Numpy Tensors at
+ initialization.
+ prepend_batch_axis (`bool`, *optional*, defaults to `False`):
+ Whether or not to add a batch axis when converting to tensors (see `tensor_type` above). Note that this
+ parameter has an effect if the parameter `tensor_type` is set, *otherwise has no effect*.
+ n_sequences (`Optional[int]`, *optional*):
+ You can give a tensor_type here to convert the lists of integers in PyTorch/TensorFlow/Numpy Tensors at
+ initialization.
+ """
+
+ def __init__(
+ self,
+ data: Optional[dict[str, Any]] = None,
+ encoding: Optional[Union[EncodingFast, Sequence[EncodingFast]]] = None,
+ tensor_type: Union[None, str, TensorType] = None,
+ prepend_batch_axis: bool = False,
+ n_sequences: Optional[int] = None,
+ ):
+ super().__init__(data)
+
+ # If encoding is not None, the fast tokenization is used
+ if encoding is not None and isinstance(encoding, EncodingFast):
+ encoding = [encoding]
+
+ self._encodings = encoding
+
+ if n_sequences is None and encoding is not None and encoding:
+ n_sequences = encoding[0].n_sequences
+
+ self._n_sequences = n_sequences
+
+ self.convert_to_tensors(tensor_type=tensor_type, prepend_batch_axis=prepend_batch_axis)
+
+ @property
+ def n_sequences(self) -> Optional[int]:
+ """
+ `Optional[int]`: The number of sequences used to generate each sample from the batch encoded in this
+ [`BatchEncoding`]. Currently can be one of `None` (unknown), `1` (a single sentence) or `2` (a pair of
+ sentences)
+ """
+ return self._n_sequences
+
+ @property
+ def is_fast(self) -> bool:
+ """
+ `bool`: Indicate whether this [`BatchEncoding`] was generated from the result of a [`PreTrainedTokenizerFast`]
+ or not.
+ """
+ return self._encodings is not None
+
+ def __getitem__(self, item: Union[int, str]) -> Union[Any, EncodingFast]:
+ """
+ If the key is a string, returns the value of the dict associated to `key` ('input_ids', 'attention_mask',
+ etc.).
+
+ If the key is an integer, get the `tokenizers.Encoding` for batch item with index `key`.
+
+ If the key is a slice, returns the value of the dict associated to `key` ('input_ids', 'attention_mask', etc.)
+ with the constraint of slice.
+ """
+ if isinstance(item, str):
+ return self.data[item]
+ elif self._encodings is not None:
+ return self._encodings[item]
+ elif isinstance(item, slice):
+ return {key: self.data[key][item] for key in self.data}
+ else:
+ raise KeyError(
+ "Invalid key. Only three types of key are available: "
+ "(1) string, (2) integers for backend Encoding, and (3) slices for data subsetting."
+ )
+
+ def __getattr__(self, item: str):
+ try:
+ return self.data[item]
+ except KeyError:
+ raise AttributeError
+
+ def __getstate__(self):
+ return {"data": self.data, "encodings": self._encodings}
+
+ def __setstate__(self, state):
+ if "data" in state:
+ self.data = state["data"]
+
+ if "encodings" in state:
+ self._encodings = state["encodings"]
+
+ # After this point:
+ # Extended properties and methods only available for fast (Rust-based) tokenizers
+ # provided by HuggingFace tokenizers library.
+
+ @property
+ def encodings(self) -> Optional[list[EncodingFast]]:
+ """
+ `Optional[list[tokenizers.Encoding]]`: The list all encodings from the tokenization process. Returns `None` if
+ the input was tokenized through Python (i.e., not a fast) tokenizer.
+ """
+ return self._encodings
+
+ def tokens(self, batch_index: int = 0) -> list[str]:
+ """
+ Return the list of tokens (sub-parts of the input strings after word/subword splitting and before conversion to
+ integer indices) at a given batch index (only works for the output of a fast tokenizer).
+
+ Args:
+ batch_index (`int`, *optional*, defaults to 0): The index to access in the batch.
+
+ Returns:
+ `list[str]`: The list of tokens at that index.
+ """
+ if not self._encodings:
+ raise ValueError(
+ "tokens() is not available when using non-fast tokenizers (e.g. instance of a `XxxTokenizerFast`"
+ " class)."
+ )
+ return self._encodings[batch_index].tokens
+
+ def sequence_ids(self, batch_index: int = 0) -> list[Optional[int]]:
+ """
+ Return a list mapping the tokens to the id of their original sentences:
+
+ - `None` for special tokens added around or between sequences,
+ - `0` for tokens corresponding to words in the first sequence,
+ - `1` for tokens corresponding to words in the second sequence when a pair of sequences was jointly
+ encoded.
+
+ Args:
+ batch_index (`int`, *optional*, defaults to 0): The index to access in the batch.
+
+ Returns:
+ `list[Optional[int]]`: A list indicating the sequence id corresponding to each token. Special tokens added
+ by the tokenizer are mapped to `None` and other tokens are mapped to the index of their corresponding
+ sequence.
+ """
+ if not self._encodings:
+ raise ValueError(
+ "sequence_ids() is not available when using non-fast tokenizers (e.g. instance of a `XxxTokenizerFast`"
+ " class)."
+ )
+ return self._encodings[batch_index].sequence_ids
+
+ def words(self, batch_index: int = 0) -> list[Optional[int]]:
+ """
+ Return a list mapping the tokens to their actual word in the initial sentence for a fast tokenizer.
+
+ Args:
+ batch_index (`int`, *optional*, defaults to 0): The index to access in the batch.
+
+ Returns:
+ `list[Optional[int]]`: A list indicating the word corresponding to each token. Special tokens added by the
+ tokenizer are mapped to `None` and other tokens are mapped to the index of their corresponding word
+ (several tokens will be mapped to the same word index if they are parts of that word).
+ """
+ if not self._encodings:
+ raise ValueError(
+ "words() is not available when using non-fast tokenizers (e.g. instance of a `XxxTokenizerFast`"
+ " class)."
+ )
+ warnings.warn(
+ "`BatchEncoding.words()` property is deprecated and should be replaced with the identical, "
+ "but more self-explanatory `BatchEncoding.word_ids()` property.",
+ FutureWarning,
+ )
+ return self.word_ids(batch_index)
+
+ def word_ids(self, batch_index: int = 0) -> list[Optional[int]]:
+ """
+ Return a list mapping the tokens to their actual word in the initial sentence for a fast tokenizer.
+
+ Args:
+ batch_index (`int`, *optional*, defaults to 0): The index to access in the batch.
+
+ Returns:
+ `list[Optional[int]]`: A list indicating the word corresponding to each token. Special tokens added by the
+ tokenizer are mapped to `None` and other tokens are mapped to the index of their corresponding word
+ (several tokens will be mapped to the same word index if they are parts of that word).
+ """
+ if not self._encodings:
+ raise ValueError(
+ "word_ids() is not available when using non-fast tokenizers (e.g. instance of a `XxxTokenizerFast`"
+ " class)."
+ )
+ return self._encodings[batch_index].word_ids
+
+ def token_to_sequence(self, batch_or_token_index: int, token_index: Optional[int] = None) -> int:
+ """
+ Get the index of the sequence represented by the given token. In the general use case, this method returns `0`
+ for a single sequence or the first sequence of a pair, and `1` for the second sequence of a pair
+
+ Can be called as:
+
+ - `self.token_to_sequence(token_index)` if batch size is 1
+ - `self.token_to_sequence(batch_index, token_index)` if batch size is greater than 1
+
+ This method is particularly suited when the input sequences are provided as pre-tokenized sequences (i.e.,
+ words are defined by the user). In this case it allows to easily associate encoded tokens with provided
+ tokenized words.
+
+ Args:
+ batch_or_token_index (`int`):
+ Index of the sequence in the batch. If the batch only comprises one sequence, this can be the index of
+ the token in the sequence.
+ token_index (`int`, *optional*):
+ If a batch index is provided in *batch_or_token_index*, this can be the index of the token in the
+ sequence.
+
+ Returns:
+ `int`: Index of the word in the input sequence.
+ """
+
+ if not self._encodings:
+ raise ValueError("token_to_sequence() is not available when using Python based tokenizers")
+ if token_index is not None:
+ batch_index = batch_or_token_index
+ else:
+ batch_index = 0
+ token_index = batch_or_token_index
+ if batch_index < 0:
+ batch_index = self._batch_size + batch_index
+ if token_index < 0:
+ token_index = self._seq_len + token_index
+ return self._encodings[batch_index].token_to_sequence(token_index)
+
+ def token_to_word(self, batch_or_token_index: int, token_index: Optional[int] = None) -> int:
+ """
+ Get the index of the word corresponding (i.e. comprising) to an encoded token in a sequence of the batch.
+
+ Can be called as:
+
+ - `self.token_to_word(token_index)` if batch size is 1
+ - `self.token_to_word(batch_index, token_index)` if batch size is greater than 1
+
+ This method is particularly suited when the input sequences are provided as pre-tokenized sequences (i.e.,
+ words are defined by the user). In this case it allows to easily associate encoded tokens with provided
+ tokenized words.
+
+ Args:
+ batch_or_token_index (`int`):
+ Index of the sequence in the batch. If the batch only comprise one sequence, this can be the index of
+ the token in the sequence.
+ token_index (`int`, *optional*):
+ If a batch index is provided in *batch_or_token_index*, this can be the index of the token in the
+ sequence.
+
+ Returns:
+ `int`: Index of the word in the input sequence.
+ """
+
+ if not self._encodings:
+ raise ValueError("token_to_word() is not available when using Python based tokenizers")
+ if token_index is not None:
+ batch_index = batch_or_token_index
+ else:
+ batch_index = 0
+ token_index = batch_or_token_index
+ if batch_index < 0:
+ batch_index = self._batch_size + batch_index
+ if token_index < 0:
+ token_index = self._seq_len + token_index
+ return self._encodings[batch_index].token_to_word(token_index)
+
+ def word_to_tokens(
+ self, batch_or_word_index: int, word_index: Optional[int] = None, sequence_index: int = 0
+ ) -> Optional[TokenSpan]:
+ """
+ Get the encoded token span corresponding to a word in a sequence of the batch.
+
+ Token spans are returned as a [`~tokenization_utils_base.TokenSpan`] with:
+
+ - **start** -- Index of the first token.
+ - **end** -- Index of the token following the last token.
+
+ Can be called as:
+
+ - `self.word_to_tokens(word_index, sequence_index: int = 0)` if batch size is 1
+ - `self.word_to_tokens(batch_index, word_index, sequence_index: int = 0)` if batch size is greater or equal to
+ 1
+
+ This method is particularly suited when the input sequences are provided as pre-tokenized sequences (i.e. words
+ are defined by the user). In this case it allows to easily associate encoded tokens with provided tokenized
+ words.
+
+ Args:
+ batch_or_word_index (`int`):
+ Index of the sequence in the batch. If the batch only comprises one sequence, this can be the index of
+ the word in the sequence.
+ word_index (`int`, *optional*):
+ If a batch index is provided in *batch_or_token_index*, this can be the index of the word in the
+ sequence.
+ sequence_index (`int`, *optional*, defaults to 0):
+ If pair of sequences are encoded in the batch this can be used to specify which sequence in the pair (0
+ or 1) the provided word index belongs to.
+
+ Returns:
+ ([`~tokenization_utils_base.TokenSpan`], *optional*): Span of tokens in the encoded sequence. Returns
+ `None` if no tokens correspond to the word. This can happen especially when the token is a special token
+ that has been used to format the tokenization. For example when we add a class token at the very beginning
+ of the tokenization.
+ """
+
+ if not self._encodings:
+ raise ValueError("word_to_tokens() is not available when using Python based tokenizers")
+ if word_index is not None:
+ batch_index = batch_or_word_index
+ else:
+ batch_index = 0
+ word_index = batch_or_word_index
+ if batch_index < 0:
+ batch_index = self._batch_size + batch_index
+ if word_index < 0:
+ word_index = self._seq_len + word_index
+ span = self._encodings[batch_index].word_to_tokens(word_index, sequence_index)
+ return TokenSpan(*span) if span is not None else None
+
+ def token_to_chars(self, batch_or_token_index: int, token_index: Optional[int] = None) -> Optional[CharSpan]:
+ """
+ Get the character span corresponding to an encoded token in a sequence of the batch.
+
+ Character spans are returned as a [`~tokenization_utils_base.CharSpan`] with:
+
+ - **start** -- Index of the first character in the original string associated to the token.
+ - **end** -- Index of the character following the last character in the original string associated to the
+ token.
+
+ Can be called as:
+
+ - `self.token_to_chars(token_index)` if batch size is 1
+ - `self.token_to_chars(batch_index, token_index)` if batch size is greater or equal to 1
+
+ Args:
+ batch_or_token_index (`int`):
+ Index of the sequence in the batch. If the batch only comprise one sequence, this can be the index of
+ the token in the sequence.
+ token_index (`int`, *optional*):
+ If a batch index is provided in *batch_or_token_index*, this can be the index of the token or tokens in
+ the sequence.
+
+ Returns:
+ [`~tokenization_utils_base.CharSpan`]: Span of characters in the original string, or None, if the token
+ (e.g. , ) doesn't correspond to any chars in the origin string.
+ """
+
+ if not self._encodings:
+ raise ValueError("token_to_chars() is not available when using Python based tokenizers")
+ if token_index is not None:
+ batch_index = batch_or_token_index
+ else:
+ batch_index = 0
+ token_index = batch_or_token_index
+ span_indices = self._encodings[batch_index].token_to_chars(token_index)
+
+ return CharSpan(*span_indices) if span_indices is not None else None
+
+ def char_to_token(
+ self, batch_or_char_index: int, char_index: Optional[int] = None, sequence_index: int = 0
+ ) -> int:
+ """
+ Get the index of the token in the encoded output comprising a character in the original string for a sequence
+ of the batch.
+
+ Can be called as:
+
+ - `self.char_to_token(char_index)` if batch size is 1
+ - `self.char_to_token(batch_index, char_index)` if batch size is greater or equal to 1
+
+ This method is particularly suited when the input sequences are provided as pre-tokenized sequences (i.e. words
+ are defined by the user). In this case it allows to easily associate encoded tokens with provided tokenized
+ words.
+
+ Args:
+ batch_or_char_index (`int`):
+ Index of the sequence in the batch. If the batch only comprise one sequence, this can be the index of
+ the word in the sequence
+ char_index (`int`, *optional*):
+ If a batch index is provided in *batch_or_token_index*, this can be the index of the word in the
+ sequence.
+ sequence_index (`int`, *optional*, defaults to 0):
+ If pair of sequences are encoded in the batch this can be used to specify which sequence in the pair (0
+ or 1) the provided character index belongs to.
+
+
+ Returns:
+ `int`: Index of the token, or None if the char index refers to a whitespace only token and whitespace is
+ trimmed with `trim_offsets=True`.
+ """
+
+ if not self._encodings:
+ raise ValueError("char_to_token() is not available when using Python based tokenizers")
+ if char_index is not None:
+ batch_index = batch_or_char_index
+ else:
+ batch_index = 0
+ char_index = batch_or_char_index
+ return self._encodings[batch_index].char_to_token(char_index, sequence_index)
+
+ def word_to_chars(
+ self, batch_or_word_index: int, word_index: Optional[int] = None, sequence_index: int = 0
+ ) -> CharSpan:
+ """
+ Get the character span in the original string corresponding to given word in a sequence of the batch.
+
+ Character spans are returned as a CharSpan NamedTuple with:
+
+ - start: index of the first character in the original string
+ - end: index of the character following the last character in the original string
+
+ Can be called as:
+
+ - `self.word_to_chars(word_index)` if batch size is 1
+ - `self.word_to_chars(batch_index, word_index)` if batch size is greater or equal to 1
+
+ Args:
+ batch_or_word_index (`int`):
+ Index of the sequence in the batch. If the batch only comprise one sequence, this can be the index of
+ the word in the sequence
+ word_index (`int`, *optional*):
+ If a batch index is provided in *batch_or_token_index*, this can be the index of the word in the
+ sequence.
+ sequence_index (`int`, *optional*, defaults to 0):
+ If pair of sequences are encoded in the batch this can be used to specify which sequence in the pair (0
+ or 1) the provided word index belongs to.
+
+ Returns:
+ `CharSpan` or `list[CharSpan]`: Span(s) of the associated character or characters in the string. CharSpan
+ are NamedTuple with:
+
+ - start: index of the first character associated to the token in the original string
+ - end: index of the character following the last character associated to the token in the original
+ string
+ """
+
+ if not self._encodings:
+ raise ValueError("word_to_chars() is not available when using Python based tokenizers")
+ if word_index is not None:
+ batch_index = batch_or_word_index
+ else:
+ batch_index = 0
+ word_index = batch_or_word_index
+ return CharSpan(*(self._encodings[batch_index].word_to_chars(word_index, sequence_index)))
+
+ def char_to_word(self, batch_or_char_index: int, char_index: Optional[int] = None, sequence_index: int = 0) -> int:
+ """
+ Get the word in the original string corresponding to a character in the original string of a sequence of the
+ batch.
+
+ Can be called as:
+
+ - `self.char_to_word(char_index)` if batch size is 1
+ - `self.char_to_word(batch_index, char_index)` if batch size is greater than 1
+
+ This method is particularly suited when the input sequences are provided as pre-tokenized sequences (i.e. words
+ are defined by the user). In this case it allows to easily associate encoded tokens with provided tokenized
+ words.
+
+ Args:
+ batch_or_char_index (`int`):
+ Index of the sequence in the batch. If the batch only comprise one sequence, this can be the index of
+ the character in the original string.
+ char_index (`int`, *optional*):
+ If a batch index is provided in *batch_or_token_index*, this can be the index of the character in the
+ original string.
+ sequence_index (`int`, *optional*, defaults to 0):
+ If pair of sequences are encoded in the batch this can be used to specify which sequence in the pair (0
+ or 1) the provided character index belongs to.
+
+
+ Returns:
+ `int` or `list[int]`: Index or indices of the associated encoded token(s).
+ """
+
+ if not self._encodings:
+ raise ValueError("char_to_word() is not available when using Python based tokenizers")
+ if char_index is not None:
+ batch_index = batch_or_char_index
+ else:
+ batch_index = 0
+ char_index = batch_or_char_index
+ return self._encodings[batch_index].char_to_word(char_index, sequence_index)
+
+ def convert_to_tensors(
+ self, tensor_type: Optional[Union[str, TensorType]] = None, prepend_batch_axis: bool = False
+ ):
+ """
+ Convert the inner content to tensors.
+
+ Args:
+ tensor_type (`str` or [`~utils.TensorType`], *optional*):
+ The type of tensors to use. If `str`, should be one of the values of the enum [`~utils.TensorType`]. If
+ `None`, no modification is done.
+ prepend_batch_axis (`int`, *optional*, defaults to `False`):
+ Whether or not to add the batch dimension during the conversion.
+ """
+ if tensor_type is None:
+ return self
+
+ # Convert to TensorType
+ if not isinstance(tensor_type, TensorType):
+ tensor_type = TensorType(tensor_type)
+
+ # Get a function reference for the correct framework
+ if tensor_type == TensorType.TENSORFLOW:
+ if not is_tf_available():
+ raise ImportError(
+ "Unable to convert output to TensorFlow tensors format, TensorFlow is not installed."
+ )
+ import tensorflow as tf
+
+ def as_tensor(value, dtype=None):
+ if len(flatten(value)) == 0 and dtype is None:
+ dtype = tf.int32
+ return tf.constant(value, dtype=dtype)
+
+ is_tensor = tf.is_tensor
+
+ elif tensor_type == TensorType.PYTORCH:
+ if not is_torch_available():
+ raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
+ import torch
+
+ def as_tensor(value, dtype=None):
+ if isinstance(value, list) and len(value) > 0 and isinstance(value[0], np.ndarray):
+ return torch.from_numpy(np.array(value))
+ if len(flatten(value)) == 0 and dtype is None:
+ dtype = torch.int64
+ return torch.tensor(value, dtype=dtype)
+
+ is_tensor = torch.is_tensor
+
+ elif tensor_type == TensorType.JAX:
+ if not is_flax_available():
+ raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.")
+ import jax.numpy as jnp # noqa: F811
+
+ def as_tensor(value, dtype=None):
+ if len(flatten(value)) == 0 and dtype is None:
+ dtype = jnp.int32
+ return jnp.array(value, dtype=dtype)
+
+ is_tensor = is_jax_tensor
+
+ elif tensor_type == TensorType.MLX:
+ if not is_mlx_available():
+ raise ImportError("Unable to convert output to MLX tensors format, MLX is not installed.")
+ import mlx.core as mx
+
+ def as_tensor(value, dtype=None):
+ if len(flatten(value)) == 0 and dtype is None:
+ dtype = mx.int32
+ return mx.array(value, dtype=dtype)
+
+ def is_tensor(obj):
+ return isinstance(obj, mx.array)
+ else:
+
+ def as_tensor(value, dtype=None):
+ if (
+ isinstance(value, (list, tuple))
+ and len(value) > 0
+ and isinstance(value[0], (list, tuple, np.ndarray))
+ ):
+ value_lens = [len(val) for val in value]
+ if len(set(value_lens)) > 1 and dtype is None:
+ # we have a ragged list so handle explicitly
+ value = as_tensor([np.asarray(val) for val in value], dtype=object)
+ if len(flatten(value)) == 0 and dtype is None:
+ dtype = np.int64
+ return np.asarray(value, dtype=dtype)
+
+ is_tensor = is_numpy_array
+
+ # Do the tensor conversion in batch
+ for key, value in self.items():
+ try:
+ if prepend_batch_axis:
+ value = [value]
+
+ if not is_tensor(value):
+ tensor = as_tensor(value)
+
+ # Removing this for now in favor of controlling the shape with `prepend_batch_axis`
+ # # at-least2d
+ # if tensor.ndim > 2:
+ # tensor = tensor.squeeze(0)
+ # elif tensor.ndim < 2:
+ # tensor = tensor[None, :]
+
+ self[key] = tensor
+ except Exception as e:
+ if key == "overflowing_tokens":
+ raise ValueError(
+ "Unable to create tensor returning overflowing tokens of different lengths. "
+ "Please see if a fast version of this tokenizer is available to have this feature available."
+ ) from e
+ raise ValueError(
+ "Unable to create tensor, you should probably activate truncation and/or padding with"
+ " 'padding=True' 'truncation=True' to have batched tensors with the same length. Perhaps your"
+ f" features (`{key}` in this case) have excessive nesting (inputs type `list` where type `int` is"
+ " expected)."
+ ) from e
+
+ return self
+
+ def to(self, device: Union[str, "torch.device"], *, non_blocking: bool = False) -> "BatchEncoding":
+ """
+ Send all values to device by calling `v.to(device, non_blocking=non_blocking)` (PyTorch only).
+
+ Args:
+ device (`str` or `torch.device`): The device to put the tensors on.
+ non_blocking (`bool`): Whether to perform the copy asynchronously.
+
+ Returns:
+ [`BatchEncoding`]: The same instance after modification.
+ """
+ requires_backends(self, ["torch"])
+
+ # This check catches things like APEX blindly calling "to" on all inputs to a module
+ # Otherwise it passes the casts down and casts the LongTensor containing the token idxs
+ # into a HalfTensor
+ if isinstance(device, str) or is_torch_device(device) or isinstance(device, int):
+ self.data = {
+ k: v.to(device=device, non_blocking=non_blocking) if hasattr(v, "to") and callable(v.to) else v
+ for k, v in self.data.items()
+ }
+ else:
+ logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.")
+ return self
+
+
+class SpecialTokensMixin:
+ """
+ A mixin derived by [`PreTrainedTokenizer`] and [`PreTrainedTokenizerFast`] to handle specific behaviors related to
+ special tokens. In particular, this class hold the attributes which can be used to directly access these special
+ tokens in a model-independent manner and allow to set and update the special tokens.
+
+ Args:
+ bos_token (`str` or `tokenizers.AddedToken`, *optional*):
+ A special token representing the beginning of a sentence.
+ eos_token (`str` or `tokenizers.AddedToken`, *optional*):
+ A special token representing the end of a sentence.
+ unk_token (`str` or `tokenizers.AddedToken`, *optional*):
+ A special token representing an out-of-vocabulary token.
+ sep_token (`str` or `tokenizers.AddedToken`, *optional*):
+ A special token separating two different sentences in the same input (used by BERT for instance).
+ pad_token (`str` or `tokenizers.AddedToken`, *optional*):
+ A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by
+ attention mechanisms or loss computation.
+ cls_token (`str` or `tokenizers.AddedToken`, *optional*):
+ A special token representing the class of the input (used by BERT for instance).
+ mask_token (`str` or `tokenizers.AddedToken`, *optional*):
+ A special token representing a masked token (used by masked-language modeling pretraining objectives, like
+ BERT).
+ additional_special_tokens (tuple or list of `str` or `tokenizers.AddedToken`, *optional*):
+ A tuple or a list of additional tokens, which will be marked as `special`, meaning that they will be
+ skipped when decoding if `skip_special_tokens` is set to `True`.
+ """
+
+ SPECIAL_TOKENS_ATTRIBUTES = [
+ "bos_token",
+ "eos_token",
+ "unk_token",
+ "sep_token",
+ "pad_token",
+ "cls_token",
+ "mask_token",
+ "additional_special_tokens",
+ ]
+
+ def __init__(self, verbose=False, **kwargs):
+ self._pad_token_type_id = 0
+ self.verbose = verbose
+ self._special_tokens_map = dict.fromkeys(self.SPECIAL_TOKENS_ATTRIBUTES)
+ self._special_tokens_map["additional_special_tokens"] = [] # for BC where it defaults to empty list
+
+ # We directly set the hidden value to allow initialization with special tokens
+ # which are not yet in the vocabulary. Necessary for serialization/de-serialization
+ # TODO clean this up at some point (probably by switching to fast tokenizers)
+
+ for key, value in kwargs.items():
+ if value is None:
+ continue
+ if key in self.SPECIAL_TOKENS_ATTRIBUTES:
+ if key == "additional_special_tokens":
+ assert isinstance(value, (list, tuple)), f"Value {value} is not a list or tuple"
+ assert all(isinstance(t, (str, AddedToken)) for t in value), (
+ "One of the tokens is not a string or an AddedToken"
+ )
+ setattr(self, key, value)
+ elif isinstance(value, (str, AddedToken)):
+ setattr(self, key, value)
+ else:
+ raise TypeError(f"Special token {key} has to be either str or AddedToken but got: {type(value)}")
+
+ def sanitize_special_tokens(self) -> int:
+ """
+ The `sanitize_special_tokens` is now deprecated kept for backward compatibility and will be removed in
+ transformers v5.
+ """
+ logger.warning_once("The `sanitize_special_tokens` will be removed in transformers v5.")
+ return self.add_tokens(self.all_special_tokens_extended, special_tokens=True)
+
+ def add_special_tokens(
+ self,
+ special_tokens_dict: dict[str, Union[str, AddedToken, Sequence[Union[str, AddedToken]]]],
+ replace_additional_special_tokens=True,
+ ) -> int:
+ """
+ Add a dictionary of special tokens (eos, pad, cls, etc.) to the encoder and link them to class attributes. If
+ special tokens are NOT in the vocabulary, they are added to it (indexed starting from the last index of the
+ current vocabulary).
+
+ When adding new tokens to the vocabulary, you should make sure to also resize the token embedding matrix of the
+ model so that its embedding matrix matches the tokenizer.
+
+ In order to do that, please use the [`~PreTrainedModel.resize_token_embeddings`] method.
+
+ Using `add_special_tokens` will ensure your special tokens can be used in several ways:
+
+ - Special tokens can be skipped when decoding using `skip_special_tokens = True`.
+ - Special tokens are carefully handled by the tokenizer (they are never split), similar to `AddedTokens`.
+ - You can easily refer to special tokens using tokenizer class attributes like `tokenizer.cls_token`. This
+ makes it easy to develop model-agnostic training and fine-tuning scripts.
+
+ When possible, special tokens are already registered for provided pretrained models (for instance
+ [`BertTokenizer`] `cls_token` is already registered to be `'[CLS]'` and XLM's one is also registered to be
+ `''`).
+
+ Args:
+ special_tokens_dict (dictionary *str* to *str*, `tokenizers.AddedToken`, or `Sequence[Union[str, AddedToken]]`):
+ Keys should be in the list of predefined special attributes: [`bos_token`, `eos_token`, `unk_token`,
+ `sep_token`, `pad_token`, `cls_token`, `mask_token`, `additional_special_tokens`].
+
+ Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer
+ assign the index of the `unk_token` to them).
+ replace_additional_special_tokens (`bool`, *optional*,, defaults to `True`):
+ If `True`, the existing list of additional special tokens will be replaced by the list provided in
+ `special_tokens_dict`. Otherwise, `self._special_tokens_map["additional_special_tokens"]` is just extended. In the former
+ case, the tokens will NOT be removed from the tokenizer's full vocabulary - they are only being flagged
+ as non-special tokens. Remember, this only affects which tokens are skipped during decoding, not the
+ `added_tokens_encoder` and `added_tokens_decoder`. This means that the previous
+ `additional_special_tokens` are still added tokens, and will not be split by the model.
+
+ Returns:
+ `int`: Number of tokens added to the vocabulary.
+
+ Examples:
+
+ ```python
+ # Let's see how to add a new classification token to GPT-2
+ tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
+ model = GPT2Model.from_pretrained("openai-community/gpt2")
+
+ special_tokens_dict = {"cls_token": ""}
+
+ num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
+ print("We have added", num_added_toks, "tokens")
+ # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e., the length of the tokenizer.
+ model.resize_token_embeddings(len(tokenizer))
+
+ assert tokenizer.cls_token == ""
+ ```"""
+ if not special_tokens_dict:
+ return 0
+
+ added_tokens = []
+ for key, value in special_tokens_dict.items():
+ assert key in self.SPECIAL_TOKENS_ATTRIBUTES, f"Key {key} is not a special token"
+
+ if self.verbose:
+ logger.info(f"Assigning {value} to the {key} key of the tokenizer")
+
+ if key == "additional_special_tokens":
+ assert isinstance(value, (list, tuple)) and all(isinstance(t, (str, AddedToken)) for t in value), (
+ f"Tokens {value} for key {key} should all be str or AddedToken instances"
+ )
+
+ to_add = []
+ for token in value:
+ if isinstance(token, str):
+ # for legacy purpose we default to stripping. `test_add_tokens_tokenizer` depends on this
+ token = AddedToken(token, rstrip=False, lstrip=False, normalized=False, special=True)
+ if not replace_additional_special_tokens and str(token) in self.additional_special_tokens:
+ continue
+ to_add.append(token)
+ if replace_additional_special_tokens and len(to_add) > 0:
+ setattr(self, key, list(to_add))
+ else:
+ self._special_tokens_map["additional_special_tokens"].extend(to_add)
+ added_tokens += to_add
+
+ else:
+ if not isinstance(value, (str, AddedToken)):
+ raise ValueError(f"Token {value} for key {key} should be a str or an AddedToken instance")
+ if isinstance(value, (str)):
+ # for legacy purpose we default to stripping. `False` depends on this
+ value = AddedToken(value, rstrip=False, lstrip=False, normalized=False, special=True)
+ if isinstance(value, AddedToken):
+ setattr(self, key, value)
+ if value not in added_tokens:
+ added_tokens.append(value)
+
+ # if we are adding tokens that were not part of the vocab, we ought to add them
+ added_tokens = self.add_tokens(added_tokens, special_tokens=True)
+ return added_tokens
+
+ def add_tokens(
+ self, new_tokens: Union[str, AddedToken, Sequence[Union[str, AddedToken]]], special_tokens: bool = False
+ ) -> int:
+ """
+ Add a list of new tokens to the tokenizer class. If the new tokens are not in the vocabulary, they are added to
+ it with indices starting from length of the current vocabulary and will be isolated before the tokenization
+ algorithm is applied. Added tokens and tokens from the vocabulary of the tokenization algorithm are therefore
+ not treated in the same way.
+
+ Note, when adding new tokens to the vocabulary, you should make sure to also resize the token embedding matrix
+ of the model so that its embedding matrix matches the tokenizer.
+
+ In order to do that, please use the [`~PreTrainedModel.resize_token_embeddings`] method.
+
+ Args:
+ new_tokens (`str`, `tokenizers.AddedToken` or a sequence of *str* or `tokenizers.AddedToken`):
+ Tokens are only added if they are not already in the vocabulary. `tokenizers.AddedToken` wraps a string
+ token to let you personalize its behavior: whether this token should only match against a single word,
+ whether this token should strip all potential whitespaces on the left side, whether this token should
+ strip all potential whitespaces on the right side, etc.
+ special_tokens (`bool`, *optional*, defaults to `False`):
+ Can be used to specify if the token is a special token. This mostly change the normalization behavior
+ (special tokens like CLS or [MASK] are usually not lower-cased for instance).
+
+ See details for `tokenizers.AddedToken` in HuggingFace tokenizers library.
+
+ Returns:
+ `int`: Number of tokens added to the vocabulary.
+
+ Examples:
+
+ ```python
+ # Let's see how to increase the vocabulary of Bert model and tokenizer
+ tokenizer = BertTokenizerFast.from_pretrained("google-bert/bert-base-uncased")
+ model = BertModel.from_pretrained("google-bert/bert-base-uncased")
+
+ num_added_toks = tokenizer.add_tokens(["new_tok1", "my_new-tok2"])
+ print("We have added", num_added_toks, "tokens")
+ # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e., the length of the tokenizer.
+ model.resize_token_embeddings(len(tokenizer))
+ ```"""
+ if not new_tokens:
+ return 0
+
+ if not isinstance(new_tokens, (list, tuple)):
+ new_tokens = [new_tokens]
+
+ return self._add_tokens(new_tokens, special_tokens=special_tokens)
+
+ def _add_tokens(self, new_tokens: Union[list[str], list[AddedToken]], special_tokens: bool = False) -> int:
+ raise NotImplementedError
+
+ @property
+ def pad_token_type_id(self) -> int:
+ """
+ `int`: Id of the padding token type in the vocabulary.
+ """
+ return self._pad_token_type_id
+
+ def __setattr__(self, key, value):
+ key_without_id = key
+ key_is_special_id = key.endswith("_id") or key.endswith("_ids")
+ if key_is_special_id:
+ key_without_id = key[:-3] if not key.endswith("_ids") else key[:-4]
+
+ if self.__dict__.get("_special_tokens_map", None) is not None and any(
+ name in self.__dict__["_special_tokens_map"] for name in [key, key_without_id]
+ ):
+ if key_is_special_id:
+ if value is not None:
+ value = (
+ self.convert_ids_to_tokens(value)
+ if key != "additional_special_tokens"
+ else [self.convert_ids_to_tokens(val) for val in value]
+ )
+ key = key_without_id
+
+ if key != "additional_special_tokens" and not isinstance(value, (str, AddedToken)) and value is not None:
+ raise ValueError(f"Cannot set a non-string value as the {key}")
+ self._special_tokens_map[key] = value
+ else:
+ super().__setattr__(key, value)
+
+ def __getattr__(self, key):
+ key_without_id = key
+ key_is_special_id = key.endswith("_id") or key.endswith("_ids")
+ if key_is_special_id:
+ key_without_id = key[:-3] if not key.endswith("_ids") else key[:-4]
+
+ if self.__dict__.get("_special_tokens_map", None) is not None and any(
+ name in self.__dict__["_special_tokens_map"] for name in [key, key_without_id]
+ ):
+ _special_tokens_map = self.__dict__["_special_tokens_map"]
+ if not key_is_special_id:
+ if _special_tokens_map[key] is None:
+ if self.verbose:
+ logger.error(f"Using {key}, but it is not set yet.")
+ return None
+ value = _special_tokens_map[key]
+ return str(value) if key != "additional_special_tokens" else [str(tok) for tok in value]
+ else:
+ attr_as_tokens = getattr(self, key_without_id)
+ return self.convert_tokens_to_ids(attr_as_tokens) if attr_as_tokens is not None else None
+
+ if key not in self.__dict__:
+ raise AttributeError(f"{self.__class__.__name__} has no attribute {key}")
+ else:
+ return super().__getattr__(key)
+
+ @property
+ def special_tokens_map(self) -> dict[str, Union[str, list[str]]]:
+ """
+ `dict[str, Union[str, list[str]]]`: A dictionary mapping special token class attributes (`cls_token`,
+ `unk_token`, etc.) to their values (`''`, `''`, etc.).
+
+ Convert potential tokens of `tokenizers.AddedToken` type to string.
+ """
+ set_attr = {}
+ for attr in self.SPECIAL_TOKENS_ATTRIBUTES:
+ attr_value = getattr(self, attr)
+ if attr_value:
+ set_attr[attr] = attr_value
+ return set_attr
+
+ @property
+ def special_tokens_map_extended(self) -> dict[str, Union[str, AddedToken, list[Union[str, AddedToken]]]]:
+ """
+ `dict[str, Union[str, tokenizers.AddedToken, list[Union[str, tokenizers.AddedToken]]]]`: A dictionary mapping
+ special token class attributes (`cls_token`, `unk_token`, etc.) to their values (`''`, `''`, etc.).
+
+ Don't convert tokens of `tokenizers.AddedToken` type to string so they can be used to control more finely how
+ special tokens are tokenized.
+ """
+ set_attr = {}
+ for attr in self.SPECIAL_TOKENS_ATTRIBUTES:
+ attr_value = self._special_tokens_map[attr]
+ if attr_value:
+ set_attr[attr] = attr_value
+ return set_attr
+
+ @property
+ def all_special_tokens_extended(self) -> list[Union[str, AddedToken]]:
+ """
+ `list[Union[str, tokenizers.AddedToken]]`: All the special tokens (`''`, `''`, etc.), the order has
+ nothing to do with the index of each tokens. If you want to know the correct indices, check
+ `self.added_tokens_encoder`. We can't create an order anymore as the keys are `AddedTokens` and not `Strings`.
+
+ Don't convert tokens of `tokenizers.AddedToken` type to string so they can be used to control more finely how
+ special tokens are tokenized.
+ """
+ all_tokens = []
+ seen = set()
+ for value in self.special_tokens_map_extended.values():
+ if isinstance(value, (list, tuple)):
+ tokens_to_add = [token for token in value if str(token) not in seen]
+ else:
+ tokens_to_add = [value] if str(value) not in seen else []
+ seen.update(map(str, tokens_to_add))
+ all_tokens.extend(tokens_to_add)
+ return all_tokens
+
+ @property
+ def all_special_tokens(self) -> list[str]:
+ """
+ `list[str]`: A list of the unique special tokens (`''`, `''`, ..., etc.).
+
+ Convert tokens of `tokenizers.AddedToken` type to string.
+ """
+ all_toks = [str(s) for s in self.all_special_tokens_extended]
+ return all_toks
+
+ @property
+ def all_special_ids(self) -> list[int]:
+ """
+ `list[int]`: List the ids of the special tokens(`''`, `''`, etc.) mapped to class attributes.
+ """
+ all_toks = self.all_special_tokens
+ all_ids = self.convert_tokens_to_ids(all_toks)
+ return all_ids
+
+ def _set_model_specific_special_tokens(self, special_tokens: list[str]):
+ """
+ Adds new special tokens to the "SPECIAL_TOKENS_ATTRIBUTES" list which will be part
+ of "self.special_tokens" and saved as a special token in tokenizer's config.
+ This allows us to dynamically add new model-type specific tokens after initializing the tokenizer.
+ For example: if the model tokenizers is multimodal, we can support special image or audio tokens.
+ """
+ self.SPECIAL_TOKENS_ATTRIBUTES = self.SPECIAL_TOKENS_ATTRIBUTES + list(special_tokens.keys())
+ for key, value in special_tokens.items():
+ if isinstance(value, (str, AddedToken)):
+ self._special_tokens_map[key] = value
+ else:
+ raise TypeError(f"Special token {key} has to be either str or AddedToken but got: {type(value)}")
+
+
+ENCODE_KWARGS_DOCSTRING = r"""
+ add_special_tokens (`bool`, *optional*, defaults to `True`):
+ Whether or not to add special tokens when encoding the sequences. This will use the underlying
+ `PretrainedTokenizerBase.build_inputs_with_special_tokens` function, which defines which tokens are
+ automatically added to the input ids. This is useful if you want to add `bos` or `eos` tokens
+ automatically.
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
+ Activates and controls padding. Accepts the following values:
+
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+ sequence is provided).
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+ acceptable input length for the model if that argument is not provided.
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
+ lengths).
+ truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `False`):
+ Activates and controls truncation. Accepts the following values:
+
+ - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or
+ to the maximum acceptable input length for the model if that argument is not provided. This will
+ truncate token by token, removing a token from the longest sequence in the pair if a pair of
+ sequences (or a batch of pairs) is provided.
+ - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the
+ maximum acceptable input length for the model if that argument is not provided. This will only
+ truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+ - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the
+ maximum acceptable input length for the model if that argument is not provided. This will only
+ truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+ - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths
+ greater than the model maximum admissible input size).
+ max_length (`int`, *optional*):
+ Controls the maximum length to use by one of the truncation/padding parameters.
+
+ If left unset or set to `None`, this will use the predefined model maximum length if a maximum length
+ is required by one of the truncation/padding parameters. If the model has no specific maximum input
+ length (like XLNet) truncation/padding to a maximum length will be deactivated.
+ stride (`int`, *optional*, defaults to 0):
+ If set to a number along with `max_length`, the overflowing tokens returned when
+ `return_overflowing_tokens=True` will contain some tokens from the end of the truncated sequence
+ returned to provide some overlap between truncated and overflowing sequences. The value of this
+ argument defines the number of overlapping tokens.
+ is_split_into_words (`bool`, *optional*, defaults to `False`):
+ Whether or not the input is already pre-tokenized (e.g., split into words). If set to `True`, the
+ tokenizer assumes the input is already split into words (for instance, by splitting it on whitespace)
+ which it will tokenize. This is useful for NER or token classification.
+ pad_to_multiple_of (`int`, *optional*):
+ If set will pad the sequence to a multiple of the provided value. Requires `padding` to be activated.
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
+ `>= 7.5` (Volta).
+ padding_side (`str`, *optional*):
+ The side on which the model should have padding applied. Should be selected between ['right', 'left'].
+ Default value is picked from the class attribute of the same name.
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors instead of list of python integers. Acceptable values are:
+
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return Numpy `np.ndarray` objects.
+"""
+
+ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING = r"""
+ return_token_type_ids (`bool`, *optional*):
+ Whether to return token type IDs. If left to the default, will return the token type IDs according to
+ the specific tokenizer's default, defined by the `return_outputs` attribute.
+
+ [What are token type IDs?](../glossary#token-type-ids)
+ return_attention_mask (`bool`, *optional*):
+ Whether to return the attention mask. If left to the default, will return the attention mask according
+ to the specific tokenizer's default, defined by the `return_outputs` attribute.
+
+ [What are attention masks?](../glossary#attention-mask)
+ return_overflowing_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not to return overflowing token sequences. If a pair of sequences of input ids (or a batch
+ of pairs) is provided with `truncation_strategy = longest_first` or `True`, an error is raised instead
+ of returning overflowing tokens.
+ return_special_tokens_mask (`bool`, *optional*, defaults to `False`):
+ Whether or not to return special tokens mask information.
+ return_offsets_mapping (`bool`, *optional*, defaults to `False`):
+ Whether or not to return `(char_start, char_end)` for each token.
+
+ This is only available on fast tokenizers inheriting from [`PreTrainedTokenizerFast`], if using
+ Python's tokenizer, this method will raise `NotImplementedError`.
+ return_length (`bool`, *optional*, defaults to `False`):
+ Whether or not to return the lengths of the encoded inputs.
+ verbose (`bool`, *optional*, defaults to `True`):
+ Whether or not to print more information and warnings.
+ **kwargs: passed to the `self.tokenize()` method
+
+ Return:
+ [`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
+
+ - **input_ids** -- List of token ids to be fed to a model.
+
+ [What are input IDs?](../glossary#input-ids)
+
+ - **token_type_ids** -- List of token type ids to be fed to a model (when `return_token_type_ids=True` or
+ if *"token_type_ids"* is in `self.model_input_names`).
+
+ [What are token type IDs?](../glossary#token-type-ids)
+
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names`).
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ - **overflowing_tokens** -- List of overflowing tokens sequences (when a `max_length` is specified and
+ `return_overflowing_tokens=True`).
+ - **num_truncated_tokens** -- Number of tokens truncated (when a `max_length` is specified and
+ `return_overflowing_tokens=True`).
+ - **special_tokens_mask** -- List of 0s and 1s, with 1 specifying added special tokens and 0 specifying
+ regular sequence tokens (when `add_special_tokens=True` and `return_special_tokens_mask=True`).
+ - **length** -- The length of the inputs (when `return_length=True`)
+"""
+
+
+INIT_TOKENIZER_DOCSTRING = r"""
+ Class attributes (overridden by derived classes)
+
+ - **vocab_files_names** (`dict[str, str]`) -- A dictionary with, as keys, the `__init__` keyword name of each
+ vocabulary file required by the model, and as associated values, the filename for saving the associated file
+ (string).
+ - **pretrained_vocab_files_map** (`dict[str, dict[str, str]]`) -- A dictionary of dictionaries, with the
+ high-level keys being the `__init__` keyword name of each vocabulary file required by the model, the
+ low-level being the `short-cut-names` of the pretrained models with, as associated values, the `url` to the
+ associated pretrained vocabulary file.
+ - **model_input_names** (`list[str]`) -- A list of inputs expected in the forward pass of the model.
+ - **padding_side** (`str`) -- The default value for the side on which the model should have padding applied.
+ Should be `'right'` or `'left'`.
+ - **truncation_side** (`str`) -- The default value for the side on which the model should have truncation
+ applied. Should be `'right'` or `'left'`.
+
+ Args:
+ model_max_length (`int`, *optional*):
+ The maximum length (in number of tokens) for the inputs to the transformer model. When the tokenizer is
+ loaded with [`~tokenization_utils_base.PreTrainedTokenizerBase.from_pretrained`], this will be set to the
+ value stored for the associated model in `max_model_input_sizes` (see above). If no value is provided, will
+ default to VERY_LARGE_INTEGER (`int(1e30)`).
+ padding_side (`str`, *optional*):
+ The side on which the model should have padding applied. Should be selected between ['right', 'left'].
+ Default value is picked from the class attribute of the same name.
+ truncation_side (`str`, *optional*):
+ The side on which the model should have truncation applied. Should be selected between ['right', 'left'].
+ Default value is picked from the class attribute of the same name.
+ chat_template (`str`, *optional*):
+ A Jinja template string that will be used to format lists of chat messages. See
+ https://huggingface.co/docs/transformers/chat_templating for a full description.
+ model_input_names (`list[string]`, *optional*):
+ The list of inputs accepted by the forward pass of the model (like `"token_type_ids"` or
+ `"attention_mask"`). Default value is picked from the class attribute of the same name.
+ bos_token (`str` or `tokenizers.AddedToken`, *optional*):
+ A special token representing the beginning of a sentence. Will be associated to `self.bos_token` and
+ `self.bos_token_id`.
+ eos_token (`str` or `tokenizers.AddedToken`, *optional*):
+ A special token representing the end of a sentence. Will be associated to `self.eos_token` and
+ `self.eos_token_id`.
+ unk_token (`str` or `tokenizers.AddedToken`, *optional*):
+ A special token representing an out-of-vocabulary token. Will be associated to `self.unk_token` and
+ `self.unk_token_id`.
+ sep_token (`str` or `tokenizers.AddedToken`, *optional*):
+ A special token separating two different sentences in the same input (used by BERT for instance). Will be
+ associated to `self.sep_token` and `self.sep_token_id`.
+ pad_token (`str` or `tokenizers.AddedToken`, *optional*):
+ A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by
+ attention mechanisms or loss computation. Will be associated to `self.pad_token` and `self.pad_token_id`.
+ cls_token (`str` or `tokenizers.AddedToken`, *optional*):
+ A special token representing the class of the input (used by BERT for instance). Will be associated to
+ `self.cls_token` and `self.cls_token_id`.
+ mask_token (`str` or `tokenizers.AddedToken`, *optional*):
+ A special token representing a masked token (used by masked-language modeling pretraining objectives, like
+ BERT). Will be associated to `self.mask_token` and `self.mask_token_id`.
+ additional_special_tokens (tuple or list of `str` or `tokenizers.AddedToken`, *optional*):
+ A tuple or a list of additional special tokens. Add them here to ensure they are skipped when decoding with
+ `skip_special_tokens` is set to True. If they are not part of the vocabulary, they will be added at the end
+ of the vocabulary.
+ clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should cleanup the spaces that were added when splitting the input text during the
+ tokenization process.
+ split_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not the special tokens should be split during the tokenization process. Passing will affect the
+ internal state of the tokenizer. The default behavior is to not split special tokens. This means that if
+ `` is the `bos_token`, then `tokenizer.tokenize("") = ['`]. Otherwise, if
+ `split_special_tokens=True`, then `tokenizer.tokenize("")` will be give `['<','s', '>']`.
+"""
+
+
+@add_end_docstrings(INIT_TOKENIZER_DOCSTRING)
+class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
+ """
+ Base class for [`PreTrainedTokenizer`] and [`PreTrainedTokenizerFast`].
+
+ Handles shared (mostly boiler plate) methods for those two classes.
+ """
+
+ vocab_files_names: dict[str, str] = {}
+ pretrained_vocab_files_map: dict[str, dict[str, str]] = {}
+ _auto_class: Optional[str] = None
+
+ # first name has to correspond to main model input name
+ # to make sure `tokenizer.pad(...)` works correctly
+ model_input_names: list[str] = ["input_ids", "token_type_ids", "attention_mask"]
+ padding_side: str = "right"
+ truncation_side: str = "right"
+ slow_tokenizer_class = None
+
+ def __init__(self, **kwargs):
+ # inputs and kwargs for saving and re-loading (see ``from_pretrained`` and ``save_pretrained``)
+ self.init_inputs = ()
+ for key in kwargs:
+ if hasattr(self, key) and callable(getattr(self, key)):
+ raise AttributeError(f"{key} conflicts with the method {key} in {self.__class__.__name__}")
+
+ self.init_kwargs = copy.deepcopy(kwargs)
+ self.name_or_path = kwargs.pop("name_or_path", "")
+ self._processor_class = kwargs.pop("processor_class", None)
+
+ # For backward compatibility we fallback to set model_max_length from max_len if provided
+ model_max_length = kwargs.pop("model_max_length", kwargs.pop("max_len", None))
+ self.model_max_length = model_max_length if model_max_length is not None else VERY_LARGE_INTEGER
+
+ # Padding and truncation side are right by default and overridden in subclasses. If specified in the kwargs, it
+ # is changed.
+ self.padding_side = kwargs.pop("padding_side", self.padding_side)
+ if self.padding_side not in ["right", "left"]:
+ raise ValueError(
+ f"Padding side should be selected between 'right' and 'left', current value: {self.padding_side}"
+ )
+
+ self.truncation_side = kwargs.pop("truncation_side", self.truncation_side)
+ if self.truncation_side not in ["right", "left"]:
+ raise ValueError(
+ f"Truncation side should be selected between 'right' and 'left', current value: {self.truncation_side}"
+ )
+
+ self.model_input_names = kwargs.pop("model_input_names", self.model_input_names)
+
+ # By default, cleaning tokenization spaces for both fast and slow tokenizers
+ self.clean_up_tokenization_spaces = kwargs.pop("clean_up_tokenization_spaces", False)
+
+ # By default, do not split special tokens for both fast and slow tokenizers
+ self.split_special_tokens = kwargs.pop("split_special_tokens", False)
+
+ self.deprecation_warnings = {} # Use to store when we have already noticed a deprecation warning (avoid overlogging).
+ self._in_target_context_manager = False
+
+ # Stores a Jinja template that formats chat histories into tokenizable strings
+ self.chat_template = kwargs.pop("chat_template", None)
+ if isinstance(self.chat_template, (list, tuple)):
+ # Chat templates are stored as lists of dicts with fixed key names,
+ # we reconstruct that into a single dict while loading them.
+ self.chat_template = {template["name"]: template["template"] for template in self.chat_template}
+
+ super().__init__(**kwargs)
+
+ self.extra_special_tokens = kwargs.pop("extra_special_tokens", {})
+ self._set_model_specific_special_tokens(special_tokens=self.extra_special_tokens)
+
+ @property
+ def max_len_single_sentence(self) -> int:
+ """
+ `int`: The maximum length of a sentence that can be fed to the model.
+ """
+ return self.model_max_length - self.num_special_tokens_to_add(pair=False)
+
+ @property
+ def max_len_sentences_pair(self) -> int:
+ """
+ `int`: The maximum combined length of a pair of sentences that can be fed to the model.
+ """
+ return self.model_max_length - self.num_special_tokens_to_add(pair=True)
+
+ @max_len_single_sentence.setter
+ def max_len_single_sentence(self, value) -> int:
+ # For backward compatibility, allow to try to setup 'max_len_single_sentence'.
+ if value == self.model_max_length - self.num_special_tokens_to_add(pair=False) and self.verbose:
+ if not self.deprecation_warnings.get("max_len_single_sentence", False):
+ logger.warning(
+ "Setting 'max_len_single_sentence' is now deprecated. This value is automatically set up."
+ )
+ self.deprecation_warnings["max_len_single_sentence"] = True
+ else:
+ raise ValueError(
+ "Setting 'max_len_single_sentence' is now deprecated. This value is automatically set up."
+ )
+
+ @max_len_sentences_pair.setter
+ def max_len_sentences_pair(self, value) -> int:
+ # For backward compatibility, allow to try to setup 'max_len_sentences_pair'.
+ if value == self.model_max_length - self.num_special_tokens_to_add(pair=True) and self.verbose:
+ if not self.deprecation_warnings.get("max_len_sentences_pair", False):
+ logger.warning(
+ "Setting 'max_len_sentences_pair' is now deprecated. This value is automatically set up."
+ )
+ self.deprecation_warnings["max_len_sentences_pair"] = True
+ else:
+ raise ValueError("Setting 'max_len_sentences_pair' is now deprecated. This value is automatically set up.")
+
+ def _set_processor_class(self, processor_class: str):
+ """Sets processor class as an attribute."""
+ self._processor_class = processor_class
+
+ @property
+ def added_tokens_decoder(self) -> dict[int, AddedToken]:
+ raise NotImplementedError()
+
+ def __repr__(self) -> str:
+ added_tokens_decoder_rep = "\n\t".join([f"{k}: {v.__repr__()}," for k, v in self.added_tokens_decoder.items()])
+ return (
+ f"{self.__class__.__name__}(name_or_path='{self.name_or_path}',"
+ f" vocab_size={self.vocab_size}, model_max_length={self.model_max_length}, is_fast={self.is_fast},"
+ f" padding_side='{self.padding_side}', truncation_side='{self.truncation_side}',"
+ f" special_tokens={self.special_tokens_map}, clean_up_tokenization_spaces={self.clean_up_tokenization_spaces},"
+ " added_tokens_decoder={\n\t" + added_tokens_decoder_rep + "\n}\n)"
+ )
+
+ def __len__(self) -> int:
+ raise NotImplementedError()
+
+ def get_vocab(self) -> dict[str, int]:
+ """
+ Returns the vocabulary as a dictionary of token to index.
+
+ `tokenizer.get_vocab()[token]` is equivalent to `tokenizer.convert_tokens_to_ids(token)` when `token` is in the
+ vocab.
+
+ Returns:
+ `dict[str, int]`: The vocabulary.
+ """
+ raise NotImplementedError()
+
+ def apply_chat_template(
+ self,
+ conversation: Union[list[dict[str, str]], list[list[dict[str, str]]]],
+ tools: Optional[list[Union[dict, Callable]]] = None,
+ documents: Optional[list[dict[str, str]]] = None,
+ chat_template: Optional[str] = None,
+ add_generation_prompt: bool = False,
+ continue_final_message: bool = False,
+ tokenize: bool = True,
+ padding: Union[bool, str, PaddingStrategy] = False,
+ truncation: bool = False,
+ max_length: Optional[int] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_dict: bool = False,
+ return_assistant_tokens_mask: bool = False,
+ tokenizer_kwargs: Optional[dict[str, Any]] = None,
+ **kwargs,
+ ) -> Union[str, list[int], list[str], list[list[int]], BatchEncoding]:
+ """
+ Converts a list of dictionaries with `"role"` and `"content"` keys to a list of token
+ ids. This method is intended for use with chat models, and will read the tokenizer's chat_template attribute to
+ determine the format and control tokens to use when converting.
+
+ Args:
+ conversation (Union[list[dict[str, str]], list[list[dict[str, str]]]]): A list of dicts
+ with "role" and "content" keys, representing the chat history so far.
+ tools (`list[Union[Dict, Callable]]`, *optional*):
+ A list of tools (callable functions) that will be accessible to the model. If the template does not
+ support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema,
+ giving the name, description and argument types for the tool. See our
+ [tool use guide](https://huggingface.co/docs/transformers/en/chat_extras#passing-tools)
+ for more information.
+ documents (`list[dict[str, str]]`, *optional*):
+ A list of dicts representing documents that will be accessible to the model if it is performing RAG
+ (retrieval-augmented generation). If the template does not support RAG, this argument will have no
+ effect. We recommend that each document should be a dict containing "title" and "text" keys.
+ chat_template (`str`, *optional*):
+ A Jinja template to use for this conversion. It is usually not necessary to pass anything to this
+ argument, as the model's template will be used by default.
+ add_generation_prompt (bool, *optional*):
+ If this is set, a prompt with the token(s) that indicate
+ the start of an assistant message will be appended to the formatted output. This is useful when you want to generate a response from the model.
+ Note that this argument will be passed to the chat template, and so it must be supported in the
+ template for this argument to have any effect.
+ continue_final_message (bool, *optional*):
+ If this is set, the chat will be formatted so that the final
+ message in the chat is open-ended, without any EOS tokens. The model will continue this message
+ rather than starting a new one. This allows you to "prefill" part of
+ the model's response for it. Cannot be used at the same time as `add_generation_prompt`.
+ tokenize (`bool`, defaults to `True`):
+ Whether to tokenize the output. If `False`, the output will be a string.
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding
+ index) among:
+
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+ sequence if provided).
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+ acceptable input length for the model if that argument is not provided.
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
+ lengths).
+ truncation (`bool`, defaults to `False`):
+ Whether to truncate sequences at the maximum length. Has no effect if tokenize is `False`.
+ max_length (`int`, *optional*):
+ Maximum length (in tokens) to use for padding or truncation. Has no effect if tokenize is `False`. If
+ not specified, the tokenizer's `max_length` attribute will be used as a default.
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors of a particular framework. Has no effect if tokenize is `False`. Acceptable
+ values are:
+ - `'tf'`: Return TensorFlow `tf.Tensor` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return NumPy `np.ndarray` objects.
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
+ return_dict (`bool`, defaults to `False`):
+ Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`.
+ tokenizer_kwargs (`dict[str: Any]`, *optional*): Additional kwargs to pass to the tokenizer.
+ return_assistant_tokens_mask (`bool`, defaults to `False`):
+ Whether to return a mask of the assistant generated tokens. For tokens generated by the assistant,
+ the mask will contain 1. For user and system tokens, the mask will contain 0.
+ This functionality is only available for chat templates that support it via the `{% generation %}` keyword.
+ **kwargs: Additional kwargs to pass to the template renderer. Will be accessible by the chat template.
+
+ Returns:
+ `Union[list[int], Dict]`: A list of token ids representing the tokenized chat so far, including control tokens. This
+ output is ready to pass to the model, either directly or via methods like `generate()`. If `return_dict` is
+ set, will return a dict of tokenizer outputs instead.
+ """
+
+ if return_dict and not tokenize:
+ raise ValueError(
+ "`return_dict=True` is incompatible with `tokenize=False`, because there is no dict "
+ "of tokenizer outputs to return."
+ )
+
+ if return_assistant_tokens_mask and not return_dict:
+ raise ValueError("`return_assistant_tokens_mask=True` is incompatible with `return_dict=False`")
+
+ if tokenizer_kwargs is None:
+ tokenizer_kwargs = {}
+
+ chat_template = self.get_chat_template(chat_template, tools)
+
+ if isinstance(conversation, (list, tuple)) and (
+ isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "messages")
+ ):
+ conversations = conversation
+ is_batched = True
+ else:
+ conversations = [conversation]
+ is_batched = False
+
+ if continue_final_message:
+ if add_generation_prompt:
+ raise ValueError(
+ "continue_final_message and add_generation_prompt are not compatible. Use continue_final_message when you want the model to continue the final message, and add_generation_prompt when you want to add a header that will prompt it to start a new assistant message instead."
+ )
+ if return_assistant_tokens_mask:
+ raise ValueError("continue_final_message is not compatible with return_assistant_tokens_mask.")
+
+ template_kwargs = {**self.special_tokens_map, **kwargs} # kwargs overwrite special tokens if both are present
+ rendered_chat, generation_indices = render_jinja_template(
+ conversations=conversations,
+ tools=tools,
+ documents=documents,
+ chat_template=chat_template,
+ return_assistant_tokens_mask=return_assistant_tokens_mask,
+ continue_final_message=continue_final_message,
+ add_generation_prompt=add_generation_prompt,
+ **template_kwargs,
+ )
+
+ if not is_batched:
+ rendered_chat = rendered_chat[0]
+
+ if tokenize:
+ out = self(
+ rendered_chat,
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ add_special_tokens=False,
+ return_tensors=return_tensors,
+ **tokenizer_kwargs,
+ )
+ if return_dict:
+ if return_assistant_tokens_mask:
+ assistant_masks = []
+ if is_batched or return_tensors:
+ input_ids = out["input_ids"]
+ else:
+ input_ids = [out["input_ids"]]
+ for i in range(len(input_ids)):
+ current_mask = [0] * len(input_ids[i])
+ for assistant_start_char, assistant_end_char in generation_indices[i]:
+ start_token = out.char_to_token(i, assistant_start_char)
+ end_token = out.char_to_token(i, assistant_end_char - 1)
+ if start_token is None:
+ # start_token is out of bounds maybe due to truncation.
+ break
+ for token_id in range(start_token, end_token + 1 if end_token else len(input_ids[i])):
+ current_mask[token_id] = 1
+ assistant_masks.append(current_mask)
+
+ if not is_batched and not return_tensors:
+ assistant_masks = assistant_masks[0]
+
+ out["assistant_masks"] = assistant_masks
+
+ if return_tensors:
+ out.convert_to_tensors(tensor_type=return_tensors)
+
+ return out
+ else:
+ return out["input_ids"]
+ else:
+ return rendered_chat
+
+ def encode_message_with_chat_template(
+ self,
+ message: dict[str, str],
+ conversation_history: Optional[list[dict[str, str]]] = None,
+ **kwargs,
+ ) -> list[int]:
+ """
+ Tokenize a single message. This method is a convenience wrapper around `apply_chat_template` that allows you
+ to tokenize messages one by one. This is useful for things like token-by-token streaming.
+ This method is not guaranteed to be perfect. For some models, it may be impossible to robustly tokenize
+ single messages. For example, if the chat template adds tokens after each message, but also has a prefix that
+ is added to the entire chat, it will be impossible to distinguish a chat-start-token from a message-start-token.
+ In these cases, this method will do its best to find the correct tokenization, but it may not be perfect.
+ **Note:** This method does not support `add_generation_prompt`. If you want to add a generation prompt,
+ you should do it separately after tokenizing the conversation.
+ Args:
+ message (`dict`):
+ A dictionary with "role" and "content" keys, representing the message to tokenize.
+ conversation_history (`list[dict]`, *optional*):
+ A list of dicts with "role" and "content" keys, representing the chat history so far. If you are
+ tokenizing messages one by one, you should pass the previous messages in the conversation here.
+ **kwargs:
+ Additional kwargs to pass to the `apply_chat_template` method.
+ Returns:
+ `list[int]`: A list of token ids representing the tokenized message.
+ """
+ if "add_generation_prompt" in kwargs:
+ raise ValueError(
+ "`encode_message_with_chat_template` does not support `add_generation_prompt`. Please add the generation prompt "
+ "separately."
+ )
+
+ if conversation_history is None or len(conversation_history) == 0:
+ return self.apply_chat_template([message], add_generation_prompt=False, tokenize=True, **kwargs)
+
+ conversation = conversation_history + [message]
+ tokens = self.apply_chat_template(conversation, add_generation_prompt=False, tokenize=True, **kwargs)
+
+ prefix_tokens = self.apply_chat_template(
+ conversation_history, add_generation_prompt=False, tokenize=True, **kwargs
+ )
+ # It's possible that the prefix tokens are not a prefix of the full list of tokens.
+ # For example, if the prefix is `User: Hi` and the full conversation is `User: Hi Assistant: Hello`.
+ # In this case, we can't simply find the prefix, so we have to do something a bit more subtle.
+ # We look for the first place where the tokens differ, and that's our split point.
+ # This is not perfect, but it's the best we can do without a token-level API.
+ # To make this more robust, we could do a diff and find the longest common subsequence, but this is
+ # a good first approximation.
+ # This is particularly important for models like Llama3 that have changed their chat template to include
+ # EOS tokens after user messages.
+ min_len = min(len(prefix_tokens), len(tokens))
+ for i in range(min_len):
+ if prefix_tokens[i] != tokens[i]:
+ return tokens[i:]
+ return tokens[min_len:]
+
+ def get_chat_template(self, chat_template: Optional[str] = None, tools: Optional[list[dict]] = None) -> str:
+ """
+ Retrieve the chat template string used for tokenizing chat messages. This template is used
+ internally by the `apply_chat_template` method and can also be used externally to retrieve the model's chat
+ template for better generation tracking.
+
+ Args:
+ chat_template (`str`, *optional*):
+ A Jinja template or the name of a template to use for this conversion.
+ It is usually not necessary to pass anything to this argument,
+ as the model's template will be used by default.
+ tools (`list[Dict]`, *optional*):
+ A list of tools (callable functions) that will be accessible to the model. If the template does not
+ support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema,
+ giving the name, description and argument types for the tool. See our
+ [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use)
+ for more information.
+
+ Returns:
+ `str`: The chat template string.
+ """
+ # First, handle the cases when the model has a dict of multiple templates
+ if isinstance(self.chat_template, dict):
+ template_dict = self.chat_template
+ if chat_template is not None and chat_template in template_dict:
+ # The user can pass the name of a template to the chat template argument instead of an entire template
+ chat_template = template_dict[chat_template]
+ elif chat_template is None:
+ if tools is not None and "tool_use" in template_dict:
+ chat_template = template_dict["tool_use"]
+ elif "default" in template_dict:
+ chat_template = template_dict["default"]
+ else:
+ raise ValueError(
+ "This model has multiple chat templates with no default specified! Please either pass a chat "
+ "template or the name of the template you wish to use to the `chat_template` argument. Available "
+ f"template names are {sorted(template_dict.keys())}."
+ )
+
+ elif chat_template is None:
+ # These are the cases when the model has a single template
+ # priority: `chat_template` argument > `tokenizer.chat_template`
+ if self.chat_template is not None:
+ chat_template = self.chat_template
+ else:
+ raise ValueError(
+ "Cannot use chat template functions because tokenizer.chat_template is not set and no template "
+ "argument was passed! For information about writing templates and setting the "
+ "tokenizer.chat_template attribute, please see the documentation at "
+ "https://huggingface.co/docs/transformers/main/en/chat_templating"
+ )
+
+ return chat_template
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ pretrained_model_name_or_path: Union[str, os.PathLike],
+ *init_inputs,
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
+ force_download: bool = False,
+ local_files_only: bool = False,
+ token: Optional[Union[str, bool]] = None,
+ revision: str = "main",
+ trust_remote_code=False,
+ **kwargs,
+ ):
+ r"""
+ Instantiate a [`~tokenization_utils_base.PreTrainedTokenizerBase`] (or a derived class) from a predefined
+ tokenizer.
+
+ Args:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ Can be either:
+
+ - A string, the *model id* of a predefined tokenizer hosted inside a model repo on huggingface.co.
+ - A path to a *directory* containing vocabulary files required by the tokenizer, for instance saved
+ using the [`~tokenization_utils_base.PreTrainedTokenizerBase.save_pretrained`] method, e.g.,
+ `./my_model_directory/`.
+ - (**Deprecated**, not applicable to all derived classes) A path or url to a single saved vocabulary
+ file (if and only if the tokenizer only requires a single vocabulary file like Bert or XLNet), e.g.,
+ `./my_model_directory/vocab.txt`.
+ cache_dir (`str` or `os.PathLike`, *optional*):
+ Path to a directory in which a downloaded predefined tokenizer vocabulary files should be cached if the
+ standard cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download the vocabulary files and override the cached versions if they
+ exist.
+ resume_download:
+ Deprecated and ignored. All downloads are now resumed by default when possible.
+ Will be removed in v5 of Transformers.
+ proxies (`dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `hf auth login` (stored in `~/.huggingface`).
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether or not to only rely on local files and not to attempt to download any files.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ subfolder (`str`, *optional*):
+ In case the relevant files are located inside a subfolder of the model repo on huggingface.co (e.g. for
+ facebook/rag-token-base), specify it here.
+ inputs (additional positional arguments, *optional*):
+ Will be passed along to the Tokenizer `__init__` method.
+ trust_remote_code (`bool`, *optional*, defaults to `False`):
+ Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
+ should only be set to `True` for repositories you trust and in which you have read the code, as it will
+ execute code present on the Hub on your local machine.
+ kwargs (additional keyword arguments, *optional*):
+ Will be passed to the Tokenizer `__init__` method. Can be used to set special tokens like `bos_token`,
+ `eos_token`, `unk_token`, `sep_token`, `pad_token`, `cls_token`, `mask_token`,
+ `additional_special_tokens`. See parameters in the `__init__` for more details.
+
+
+
+ Passing `token=True` is required when you want to use a private model.
+
+
+
+ Examples:
+
+ ```python
+ # We can't instantiate directly the base class *PreTrainedTokenizerBase* so let's show our examples on a derived class: BertTokenizer
+ # Download vocabulary from huggingface.co and cache.
+ tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")
+
+ # Download vocabulary from huggingface.co (user-uploaded) and cache.
+ tokenizer = BertTokenizer.from_pretrained("dbmdz/bert-base-german-cased")
+
+ # If vocabulary files are in a directory (e.g. tokenizer was saved using *save_pretrained('./test/saved_model/')*)
+ tokenizer = BertTokenizer.from_pretrained("./test/saved_model/")
+
+ # If the tokenizer uses a single vocabulary file, you can point directly to this file
+ tokenizer = BertTokenizer.from_pretrained("./test/saved_model/my_vocab.txt")
+
+ # You can link tokens to special vocabulary when instantiating
+ tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased", unk_token="")
+ # You should be sure '' is in the vocabulary when doing that.
+ # Otherwise use tokenizer.add_special_tokens({'unk_token': ''}) instead)
+ assert tokenizer.unk_token == ""
+ ```"""
+ resume_download = kwargs.pop("resume_download", None)
+ proxies = kwargs.pop("proxies", None)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ subfolder = kwargs.pop("subfolder", None)
+ from_pipeline = kwargs.pop("_from_pipeline", None)
+ from_auto_class = kwargs.pop("_from_auto", False)
+ commit_hash = kwargs.pop("_commit_hash", None)
+ gguf_file = kwargs.get("gguf_file")
+
+ if use_auth_token is not None:
+ warnings.warn(
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+ FutureWarning,
+ )
+ if token is not None:
+ raise ValueError(
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+ )
+ token = use_auth_token
+
+ user_agent = {"file_type": "tokenizer", "from_auto_class": from_auto_class, "is_fast": "Fast" in cls.__name__}
+ if from_pipeline is not None:
+ user_agent["using_pipeline"] = from_pipeline
+
+ if is_offline_mode() and not local_files_only:
+ logger.info("Offline mode: forcing local_files_only=True")
+ local_files_only = True
+
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+ vocab_files = {}
+ init_configuration = {}
+
+ is_local = os.path.isdir(pretrained_model_name_or_path)
+ single_file_id = None
+ if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
+ if len(cls.vocab_files_names) > 1 and not gguf_file:
+ raise ValueError(
+ f"Calling {cls.__name__}.from_pretrained() with the path to a single file or url is not "
+ "supported for this tokenizer. Use a model identifier or the path to a directory instead."
+ )
+ warnings.warn(
+ f"Calling {cls.__name__}.from_pretrained() with the path to a single file or url is deprecated and "
+ "won't be possible anymore in v5. Use a model identifier or the path to a directory instead.",
+ FutureWarning,
+ )
+ file_id = list(cls.vocab_files_names.keys())[0]
+
+ vocab_files[file_id] = pretrained_model_name_or_path
+ single_file_id = file_id
+ else:
+ if gguf_file:
+ vocab_files["vocab_file"] = gguf_file
+ else:
+ # At this point pretrained_model_name_or_path is either a directory or a model identifier name
+ additional_files_names = {
+ "added_tokens_file": ADDED_TOKENS_FILE, # kept only for legacy
+ "special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE, # kept only for legacy
+ "tokenizer_config_file": TOKENIZER_CONFIG_FILE,
+ # tokenizer_file used to initialize a slow from a fast. Properly copy the `addedTokens` instead of adding in random orders
+ "tokenizer_file": FULL_TOKENIZER_FILE,
+ "chat_template_file": CHAT_TEMPLATE_FILE,
+ }
+
+ vocab_files = {**cls.vocab_files_names, **additional_files_names}
+ if "tokenizer_file" in vocab_files:
+ # Try to get the tokenizer config to see if there are versioned tokenizer files.
+ fast_tokenizer_file = FULL_TOKENIZER_FILE
+
+ try:
+ resolved_config_file = cached_file(
+ pretrained_model_name_or_path,
+ TOKENIZER_CONFIG_FILE,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ local_files_only=local_files_only,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ _raise_exceptions_for_missing_entries=False,
+ _commit_hash=commit_hash,
+ )
+ except OSError:
+ # Re-raise any error raised by cached_file in order to get a helpful error message
+ raise
+ except Exception:
+ # For any other exception, we throw a generic error.
+ raise OSError(
+ f"Can't load tokenizer for '{pretrained_model_name_or_path}'. If you were trying to load it from "
+ "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
+ f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
+ f"containing all relevant files for a {cls.__name__} tokenizer."
+ )
+
+ commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
+ if resolved_config_file is not None:
+ with open(resolved_config_file, encoding="utf-8") as reader:
+ tokenizer_config = json.load(reader)
+ if "fast_tokenizer_files" in tokenizer_config:
+ fast_tokenizer_file = get_fast_tokenizer_file(tokenizer_config["fast_tokenizer_files"])
+ vocab_files["tokenizer_file"] = fast_tokenizer_file
+
+ # This block looks for any extra chat template files
+ if is_local:
+ template_dir = Path(pretrained_model_name_or_path, CHAT_TEMPLATE_DIR)
+ if template_dir.is_dir():
+ for template_file in template_dir.glob("*.jinja"):
+ template_name = template_file.name.removesuffix(".jinja")
+ vocab_files[f"chat_template_{template_name}"] = (
+ f"{CHAT_TEMPLATE_DIR}/{template_file.name}"
+ )
+ else:
+ for template in list_repo_templates(
+ pretrained_model_name_or_path,
+ local_files_only=local_files_only,
+ revision=revision,
+ cache_dir=cache_dir,
+ token=token,
+ ):
+ template = template.removesuffix(".jinja")
+ vocab_files[f"chat_template_{template}"] = f"{CHAT_TEMPLATE_DIR}/{template}.jinja"
+
+ remote_files = []
+ if not is_local and not local_files_only:
+ try:
+ remote_files = list_repo_files(pretrained_model_name_or_path)
+ except Exception:
+ remote_files = []
+ elif pretrained_model_name_or_path and os.path.isdir(pretrained_model_name_or_path):
+ remote_files = os.listdir(pretrained_model_name_or_path)
+
+ if "tokenizer_file" in vocab_files and not re.search(vocab_files["tokenizer_file"], "".join(remote_files)):
+ # mistral tokenizer names are different, but we can still convert them if
+ # mistral common is not there
+ other_pattern = r"tekken\.json|tokenizer\.model\.*"
+ if match := re.search(other_pattern, "\n".join(remote_files)):
+ vocab_files["vocab_file"] = match.group()
+
+ resolved_vocab_files = {}
+ for file_id, file_path in vocab_files.items():
+ if file_path is None:
+ resolved_vocab_files[file_id] = None
+ elif single_file_id == file_id:
+ if os.path.isfile(file_path):
+ resolved_vocab_files[file_id] = file_path
+ elif is_remote_url(file_path):
+ resolved_vocab_files[file_id] = download_url(file_path, proxies=proxies)
+ else:
+ try:
+ resolved_vocab_files[file_id] = cached_file(
+ pretrained_model_name_or_path,
+ file_path,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ token=token,
+ user_agent=user_agent,
+ revision=revision,
+ subfolder=subfolder,
+ _raise_exceptions_for_missing_entries=False,
+ _commit_hash=commit_hash,
+ )
+ except OSError:
+ # Re-raise any error raised by cached_file in order to get a helpful error message
+ raise
+ except Exception:
+ # For any other exception, we throw a generic error.
+ raise OSError(
+ f"Can't load tokenizer for '{pretrained_model_name_or_path}'. If you were trying to load it from "
+ "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
+ f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
+ f"containing all relevant files for a {cls.__name__} tokenizer."
+ )
+ commit_hash = extract_commit_hash(resolved_vocab_files[file_id], commit_hash)
+
+ for file_id, file_path in vocab_files.items():
+ if file_id not in resolved_vocab_files:
+ continue
+
+ if is_local:
+ logger.info(f"loading file {file_path}")
+ else:
+ logger.info(f"loading file {file_path} from cache at {resolved_vocab_files[file_id]}")
+
+ return cls._from_pretrained(
+ resolved_vocab_files,
+ pretrained_model_name_or_path,
+ init_configuration,
+ *init_inputs,
+ token=token,
+ cache_dir=cache_dir,
+ local_files_only=local_files_only,
+ _commit_hash=commit_hash,
+ _is_local=is_local,
+ trust_remote_code=trust_remote_code,
+ **kwargs,
+ )
+
+ @classmethod
+ def _from_pretrained(
+ cls,
+ resolved_vocab_files,
+ pretrained_model_name_or_path,
+ init_configuration,
+ *init_inputs,
+ token=None,
+ cache_dir=None,
+ local_files_only=False,
+ _commit_hash=None,
+ _is_local=False,
+ trust_remote_code=False,
+ **kwargs,
+ ):
+ # We instantiate fast tokenizers based on a slow tokenizer if we don't have access to the tokenizer.json
+ # file or if `from_slow` is set to True.
+ from_slow = kwargs.get("from_slow", False)
+ gguf_file = kwargs.get("gguf_file")
+ has_tokenizer_file = resolved_vocab_files.get("tokenizer_file", None) is not None
+
+ # If one passes a GGUF file path to `gguf_file` there is no need for this check as the tokenizer will be
+ # loaded directly from the GGUF file.
+ if (from_slow or not has_tokenizer_file) and cls.slow_tokenizer_class is not None and not gguf_file:
+ slow_tokenizer = (cls.slow_tokenizer_class)._from_pretrained(
+ copy.deepcopy(resolved_vocab_files),
+ pretrained_model_name_or_path,
+ copy.deepcopy(init_configuration),
+ *init_inputs,
+ token=token,
+ cache_dir=cache_dir,
+ local_files_only=local_files_only,
+ _commit_hash=_commit_hash,
+ **(copy.deepcopy(kwargs)),
+ )
+ else:
+ slow_tokenizer = None
+
+ # Prepare tokenizer initialization kwargs
+ # Did we saved some inputs and kwargs to reload ?
+ tokenizer_config_file = resolved_vocab_files.pop("tokenizer_config_file", None)
+ if tokenizer_config_file is not None:
+ with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle:
+ init_kwargs = json.load(tokenizer_config_handle)
+ # First attempt. We get tokenizer_class from tokenizer_config to check mismatch between tokenizers.
+ config_tokenizer_class = init_kwargs.get("tokenizer_class")
+ init_kwargs.pop("tokenizer_class", None)
+ if not has_tokenizer_file:
+ init_kwargs.pop("tokenizer_file", None)
+ saved_init_inputs = init_kwargs.pop("init_inputs", ())
+ if not init_inputs:
+ init_inputs = saved_init_inputs
+ else:
+ config_tokenizer_class = None
+ init_kwargs = init_configuration
+
+ # If independent chat template file(s) exist, they take priority over template entries in the tokenizer config
+ chat_templates = {}
+ chat_template_file = resolved_vocab_files.pop("chat_template_file", None)
+ extra_chat_templates = [key for key in resolved_vocab_files if key.startswith("chat_template_")]
+ if chat_template_file is not None:
+ with open(chat_template_file, encoding="utf-8") as chat_template_handle:
+ chat_templates["default"] = chat_template_handle.read()
+ for extra_chat_template in extra_chat_templates:
+ template_file = resolved_vocab_files.pop(extra_chat_template, None)
+ if template_file is None:
+ continue # I think this should never happen, but just in case
+ template_name = extra_chat_template.removeprefix("chat_template_")
+ with open(template_file, encoding="utf8") as chat_template_handle:
+ chat_templates[template_name] = chat_template_handle.read()
+ if len(chat_templates) == 1 and "default" in chat_templates:
+ init_kwargs["chat_template"] = chat_templates["default"]
+ elif chat_templates:
+ init_kwargs["chat_template"] = chat_templates
+
+ if not _is_local:
+ if "auto_map" in init_kwargs:
+ # For backward compatibility with odl format.
+ if isinstance(init_kwargs["auto_map"], (tuple, list)):
+ init_kwargs["auto_map"] = {"AutoTokenizer": init_kwargs["auto_map"]}
+
+ if config_tokenizer_class is None:
+ # Matt: This entire block is only used to decide if the tokenizer class matches the class in the repo.
+ # If not, it raises a warning, but otherwise continues. Since we mostly load tokenizers with
+ # AutoTokenizer these days, it seems like a lot of work (and a source of bugs) for little gain.
+ # Maybe we can just remove this entirely?
+ from .models.auto.configuration_auto import AutoConfig # tests_ignore
+
+ # Second attempt. If we have not yet found tokenizer_class, let's try to use the config.
+ try:
+ config = AutoConfig.from_pretrained(
+ pretrained_model_name_or_path,
+ token=token,
+ cache_dir=cache_dir,
+ local_files_only=local_files_only,
+ trust_remote_code=trust_remote_code,
+ _commit_hash=_commit_hash,
+ )
+ config_tokenizer_class = config.tokenizer_class
+ except (OSError, ValueError, KeyError):
+ # skip if an error occurred.
+ config = None
+ if config_tokenizer_class is None:
+ # Third attempt. If we have not yet found the original type of the tokenizer,
+ # we are loading we see if we can infer it from the type of the configuration file
+ from .models.auto.tokenization_auto import TOKENIZER_MAPPING_NAMES # tests_ignore
+
+ if hasattr(config, "model_type"):
+ model_type = config.model_type
+ else:
+ # Fallback: use pattern matching on the string.
+ model_type = None
+ for pattern in TOKENIZER_MAPPING_NAMES:
+ if pattern in str(pretrained_model_name_or_path):
+ model_type = pattern
+ break
+
+ if model_type is not None:
+ config_tokenizer_class, config_tokenizer_class_fast = TOKENIZER_MAPPING_NAMES.get(
+ model_type, (None, None)
+ )
+ if config_tokenizer_class is None:
+ config_tokenizer_class = config_tokenizer_class_fast
+
+ if config_tokenizer_class is not None:
+ if cls.__name__.replace("Fast", "") != config_tokenizer_class.replace("Fast", ""):
+ logger.warning(
+ "The tokenizer class you load from this checkpoint is not the same type as the class this"
+ " function is called from. It may result in unexpected tokenization. \nThe tokenizer class you"
+ f" load from this checkpoint is '{config_tokenizer_class}'. \nThe class this function is called"
+ f" from is '{cls.__name__}'."
+ )
+
+ # Update with newly provided kwargs
+ init_kwargs.update(kwargs)
+
+ # Merge resolved_vocab_files arguments in init_kwargs.
+ added_tokens_file = resolved_vocab_files.pop("added_tokens_file", None)
+ special_tokens_map_file = resolved_vocab_files.pop("special_tokens_map_file", None)
+ for args_name, file_path in resolved_vocab_files.items():
+ if args_name not in init_kwargs:
+ init_kwargs[args_name] = file_path
+ tokenizer_file = resolved_vocab_files.pop("tokenizer_file", None)
+
+ if slow_tokenizer is not None:
+ init_kwargs["__slow_tokenizer"] = slow_tokenizer
+ init_kwargs["name_or_path"] = pretrained_model_name_or_path
+
+ #### Handle tokenizer serialization of added and special tokens
+ added_tokens_decoder: dict[int, AddedToken] = {}
+ added_tokens_map: dict[str, AddedToken] = {}
+ # if we have info on the slow added tokens
+ if "added_tokens_decoder" in init_kwargs:
+ for idx, token in init_kwargs["added_tokens_decoder"].items():
+ if isinstance(token, dict):
+ token = AddedToken(**token)
+ if isinstance(token, AddedToken):
+ added_tokens_decoder[int(idx)] = token
+ added_tokens_map[str(token)] = token
+ else:
+ raise TypeError(
+ f"Found a {token.__class__} in the saved `added_tokens_decoder`, should be a dictionary or an AddedToken instance"
+ )
+ else:
+ # begin legacy: read the added_tokens_file and update kwargs with special_tokens_map if modified
+ if special_tokens_map_file is not None:
+ with open(special_tokens_map_file, encoding="utf-8") as special_tokens_map_handle:
+ special_tokens_map = json.load(special_tokens_map_handle)
+ for key, value in special_tokens_map.items():
+ if key in kwargs and kwargs[key]:
+ # This value has already been redefined by the kwargs
+ # We keep this new value and ignore the one stored in the special_tokens_map_file
+ continue
+ if isinstance(value, dict):
+ value["special"] = True
+ value = AddedToken(**value)
+ elif key == "additional_special_tokens" and isinstance(value, list):
+ additional_special_tokens = init_kwargs.pop("additional_special_tokens", []) or []
+ for token in value:
+ if isinstance(token, dict):
+ token["special"] = True
+ token = AddedToken(**token)
+ if token not in additional_special_tokens:
+ additional_special_tokens.append(token)
+ value = additional_special_tokens
+ init_kwargs[key] = value
+
+ # slow -> slow|fast, legacy: convert the `"added_tokens.json"` file to `added_tokens_decoder`.
+ # this is for legacy purpose. We don't add the tokens after init for efficiency.
+ if added_tokens_file is not None:
+ special_tokens = []
+ for key in cls.SPECIAL_TOKENS_ATTRIBUTES & init_kwargs.keys():
+ if init_kwargs[key] is not None:
+ if key == "additional_special_tokens":
+ special_tokens += [str(token) for token in init_kwargs[key]]
+ else:
+ special_tokens.append(str(init_kwargs[key]))
+
+ with open(added_tokens_file, encoding="utf-8") as added_tokens_handle:
+ added_tok_encoder = json.load(added_tokens_handle)
+ for str_token, index in added_tok_encoder.items():
+ # if index not in added_tokens_decoder and str_token not in added_tokens_map:
+ special = str_token in special_tokens
+ added_tokens_decoder[index] = AddedToken(
+ str_token, rstrip=False, lstrip=False, normalized=not special, special=special
+ )
+ added_tokens_map[str(token)] = added_tokens_decoder[index]
+
+ # allows converting a fast -> slow: add the `tokenizer.json`'s `"added_tokens"` to the slow tokenizer
+ # if `tokenizer_config.json` is `None`
+ if tokenizer_file is not None:
+ # This is for slow so can be done before
+ with open(tokenizer_file, encoding="utf-8") as tokenizer_file_handle:
+ tokenizer_file_handle = json.load(tokenizer_file_handle)
+ added_tokens = tokenizer_file_handle.pop("added_tokens")
+ for serialized_tokens in added_tokens:
+ idx = serialized_tokens.pop("id")
+ added_tokens_decoder[idx] = AddedToken(**serialized_tokens)
+ added_tokens_map[str(added_tokens_decoder[idx])] = added_tokens_decoder[idx]
+ # end legacy
+
+ # Passing AddedTokens and not strings to the class to prevent it from casting the string to a different AddedToken
+ # convert {'__type': 'AddedToken', 'content': '', 'lstrip': False, 'normalized': True, ...} to AddedTokens
+ init_kwargs["added_tokens_decoder"] = added_tokens_decoder
+ init_kwargs = cls.convert_added_tokens(init_kwargs, save=False)
+ for key in cls.SPECIAL_TOKENS_ATTRIBUTES & init_kwargs.keys():
+ if added_tokens_map != {} and init_kwargs[key] is not None:
+ if key != "additional_special_tokens":
+ init_kwargs[key] = added_tokens_map.get(str(init_kwargs[key]), init_kwargs[key])
+
+ # Instantiate the tokenizer.
+ try:
+ tokenizer = cls(*init_inputs, **init_kwargs)
+ except import_protobuf_decode_error():
+ logger.info(
+ "Unable to load tokenizer model from SPM, loading from TikToken will be attempted instead."
+ "(Google protobuf error: Tried to load SPM model with non-SPM vocab file).",
+ )
+ return False
+ except RuntimeError as e:
+ if "sentencepiece_processor.cc" in str(e):
+ logger.info(
+ "Unable to load tokenizer model from SPM, loading from TikToken will be attempted instead."
+ "(SentencePiece RuntimeError: Tried to load SPM model with non-SPM vocab file).",
+ )
+ return False
+ except OSError:
+ raise OSError(
+ "Unable to load vocabulary from file. "
+ "Please check that the provided vocabulary is accessible and not corrupted."
+ )
+
+ if added_tokens_decoder != {} and max(list(added_tokens_decoder.keys())[-1], 0) > tokenizer.vocab_size:
+ logger.info(
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are"
+ " fine-tuned or trained."
+ )
+ try:
+ vocab_size = tokenizer.vocab_size
+ except NotImplementedError:
+ vocab_size = 0
+
+ # Optionally patches mistral tokenizers with wrong regex
+ if (
+ vocab_size > 100000
+ and hasattr(tokenizer, "_tokenizer")
+ and getattr(tokenizer._tokenizer, "pre_tokenizer", None) is not None
+ ):
+ tokenizer = cls._patch_mistral_regex(
+ tokenizer,
+ pretrained_model_name_or_path,
+ token=token,
+ cache_dir=cache_dir,
+ local_files_only=local_files_only,
+ _commit_hash=_commit_hash,
+ _is_local=_is_local,
+ init_kwargs=init_kwargs,
+ fix_mistral_regex=kwargs.get("fix_mistral_regex"),
+ )
+
+ return tokenizer
+
+ @classmethod
+ def _patch_mistral_regex(
+ cls,
+ tokenizer,
+ pretrained_model_name_or_path,
+ token=None,
+ cache_dir=None,
+ local_files_only=False,
+ _commit_hash=None,
+ _is_local=False,
+ init_kwargs=None,
+ fix_mistral_regex=None,
+ ):
+ """
+ Patches mistral related tokenizers with incorrect regex if detected
+ 1) Local file with an associated config saved next to it
+ >> Model type one of the mistral models (on older versions)
+ 2) Remote models on the hub from official mistral models
+ >> Tags including `base_model:.*mistralai`
+ """
+ from huggingface_hub import model_info
+
+ def is_base_mistral(model_id: str) -> bool:
+ model = model_info(model_id)
+ if model.tags is not None:
+ if re.search("base_model:.*mistralai", "".join(model.tags)):
+ return True
+ return False
+
+ if is_offline_mode():
+ _is_local = True
+
+ if pretrained_model_name_or_path is not None and (
+ _is_local or (not _is_local and is_base_mistral(pretrained_model_name_or_path))
+ ):
+ _config_file = cached_file(
+ pretrained_model_name_or_path,
+ "config.json",
+ cache_dir=cache_dir,
+ token=token,
+ local_files_only=local_files_only,
+ _raise_exceptions_for_missing_entries=False,
+ _raise_exceptions_for_connection_errors=False,
+ _commit_hash=_commit_hash,
+ )
+
+ # Detected using a (local) mistral tokenizer
+ mistral_config_detected = False
+ if _config_file is not None:
+ with open(_config_file, encoding="utf-8") as f:
+ _config = json.load(f)
+ transformers_version = _config.get("transformers_version")
+ transformers_model_type = _config.get("model_type")
+
+ # Detect if we can skip the mistral fix by
+ # a) having a non-mistral tokenizer
+ # b) fixed version of transformers
+ if transformers_version and version.parse(transformers_version) <= version.parse("4.57.2"):
+ if (
+ _is_local
+ and transformers_model_type is not None
+ and transformers_model_type
+ not in [
+ "mistral",
+ "mistral3",
+ "voxtral",
+ "ministral",
+ "pixtral",
+ ]
+ ):
+ return tokenizer
+ elif transformers_version and version.parse(transformers_version) >= version.parse("5.0.0"):
+ return tokenizer
+
+ mistral_config_detected = True
+
+ if mistral_config_detected or (not _is_local and is_base_mistral(pretrained_model_name_or_path)):
+ # Expose the `fix_mistral_regex` flag on the tokenizer when provided, even if no correction is applied.
+ if init_kwargs and "fix_mistral_regex" in init_kwargs:
+ setattr(tokenizer, "fix_mistral_regex", init_kwargs["fix_mistral_regex"])
+
+ # only warn if its not explicitly passed
+ if fix_mistral_regex is None and not getattr(tokenizer, "fix_mistral_regex", False):
+ setattr(tokenizer, "fix_mistral_regex", False)
+ logger.warning(
+ f"The tokenizer you are loading from '{pretrained_model_name_or_path}'"
+ f" with an incorrect regex pattern: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/discussions/84#69121093e8b480e709447d5e."
+ " This will lead to incorrect tokenization. You should set the `fix_mistral_regex=True` flag when loading this tokenizer to fix this issue."
+ )
+ elif fix_mistral_regex is True or getattr(tokenizer, "fix_mistral_regex", False):
+ setattr(tokenizer, "fix_mistral_regex", True)
+ import tokenizers
+
+ tokenizer.backend_tokenizer.pre_tokenizer[0] = tokenizers.pre_tokenizers.Split(
+ pattern=tokenizers.Regex(
+ r"[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+"
+ ),
+ behavior="isolated",
+ )
+ return tokenizer
+
+ @staticmethod
+ def _eventually_correct_t5_max_length(pretrained_model_name_or_path, max_model_length, init_max_model_length):
+ # This method should be deleted in Transformers v5
+ # Its only purpose is to potentially throw a warning
+ # that incorrectly defined max lengths of T5's tokenizer are used
+ # which we will correct in Transformers v5.
+ return max_model_length
+
+ @classmethod
+ def convert_added_tokens(cls, obj: Union[AddedToken, Any], save=False, add_type_field=True):
+ if isinstance(obj, dict) and "__type" in obj and obj["__type"] == "AddedToken":
+ obj.pop("__type")
+ return AddedToken(**obj)
+ if isinstance(obj, AddedToken) and save:
+ obj = obj.__getstate__()
+ if add_type_field:
+ obj["__type"] = "AddedToken"
+ else:
+ # Don't save "special" for previous tokenizers
+ obj.pop("special")
+ return obj
+ elif isinstance(obj, (list, tuple)):
+ return [cls.convert_added_tokens(o, save=save, add_type_field=add_type_field) for o in obj]
+ elif isinstance(obj, dict):
+ return {k: cls.convert_added_tokens(v, save=save, add_type_field=add_type_field) for k, v in obj.items()}
+ return obj
+
+ def save_chat_templates(
+ self,
+ save_directory: Union[str, os.PathLike],
+ tokenizer_config: dict,
+ filename_prefix: Optional[str],
+ save_jinja_files: bool,
+ ):
+ """
+ Writes chat templates out to the save directory if we're using the new format, and removes them from
+ the tokenizer config if present. If we're using the legacy format, it doesn't write any files, and instead
+ writes the templates to the tokenizer config in the correct format.
+ """
+ chat_template_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + CHAT_TEMPLATE_FILE
+ )
+ chat_template_dir = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + CHAT_TEMPLATE_DIR
+ )
+
+ saved_raw_chat_template_files = []
+ if save_jinja_files and isinstance(self.chat_template, str):
+ # New format for single templates is to save them as chat_template.jinja
+ with open(chat_template_file, "w", encoding="utf-8") as f:
+ f.write(self.chat_template)
+ logger.info(f"chat template saved in {chat_template_file}")
+ saved_raw_chat_template_files.append(chat_template_file)
+ if "chat_template" in tokenizer_config:
+ tokenizer_config.pop("chat_template") # To ensure it doesn't somehow end up in the config too
+ elif save_jinja_files and isinstance(self.chat_template, dict):
+ # New format for multiple templates is to save the default as chat_template.jinja
+ # and the other templates in the chat_templates/ directory
+ for template_name, template in self.chat_template.items():
+ if template_name == "default":
+ with open(chat_template_file, "w", encoding="utf-8") as f:
+ f.write(self.chat_template["default"])
+ logger.info(f"chat template saved in {chat_template_file}")
+ saved_raw_chat_template_files.append(chat_template_file)
+ else:
+ Path(chat_template_dir).mkdir(exist_ok=True)
+ template_filepath = os.path.join(chat_template_dir, f"{template_name}.jinja")
+ with open(template_filepath, "w", encoding="utf-8") as f:
+ f.write(template)
+ logger.info(f"chat template saved in {template_filepath}")
+ saved_raw_chat_template_files.append(template_filepath)
+ if "chat_template" in tokenizer_config:
+ tokenizer_config.pop("chat_template") # To ensure it doesn't somehow end up in the config too
+ elif isinstance(self.chat_template, dict):
+ # Legacy format for multiple templates:
+ # chat template dicts are saved to the config as lists of dicts with fixed key names.
+ tokenizer_config["chat_template"] = [{"name": k, "template": v} for k, v in self.chat_template.items()]
+ elif self.chat_template is not None:
+ # Legacy format for single templates: Just make them a key in tokenizer_config.json
+ tokenizer_config["chat_template"] = self.chat_template
+ return tokenizer_config, saved_raw_chat_template_files
+
+ def save_pretrained(
+ self,
+ save_directory: Union[str, os.PathLike],
+ legacy_format: Optional[bool] = None,
+ filename_prefix: Optional[str] = None,
+ push_to_hub: bool = False,
+ **kwargs,
+ ) -> tuple[str, ...]:
+ """
+ Save the full tokenizer state.
+
+
+ This method make sure the full tokenizer can then be re-loaded using the
+ [`~tokenization_utils_base.PreTrainedTokenizer.from_pretrained`] class method..
+
+ Warning,None This won't save modifications you may have applied to the tokenizer after the instantiation (for
+ instance, modifying `tokenizer.do_lower_case` after creation).
+
+ Args:
+ save_directory (`str` or `os.PathLike`): The path to a directory where the tokenizer will be saved.
+ legacy_format (`bool`, *optional*):
+ Only applicable for a fast tokenizer. If unset (default), will save the tokenizer in the unified JSON
+ format as well as in legacy format if it exists, i.e. with tokenizer specific vocabulary and a separate
+ added_tokens files.
+
+ If `False`, will only save the tokenizer in the unified JSON format. This format is incompatible with
+ "slow" tokenizers (not powered by the *tokenizers* library), so the tokenizer will not be able to be
+ loaded in the corresponding "slow" tokenizer.
+
+ If `True`, will save the tokenizer in legacy format. If the "slow" tokenizer doesn't exits, a value
+ error is raised.
+ filename_prefix (`str`, *optional*):
+ A prefix to add to the names of the files saved by the tokenizer.
+ push_to_hub (`bool`, *optional*, defaults to `False`):
+ Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
+ repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
+ namespace).
+ kwargs (`dict[str, Any]`, *optional*):
+ Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
+
+ Returns:
+ A tuple of `str`: The files saved.
+ """
+ use_auth_token = kwargs.pop("use_auth_token", None)
+
+ if use_auth_token is not None:
+ warnings.warn(
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+ FutureWarning,
+ )
+ if kwargs.get("token") is not None:
+ raise ValueError(
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+ )
+ kwargs["token"] = use_auth_token
+
+ if os.path.isfile(save_directory):
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
+ return
+
+ os.makedirs(save_directory, exist_ok=True)
+
+ if push_to_hub:
+ commit_message = kwargs.pop("commit_message", None)
+ repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
+ repo_id = self._create_repo(repo_id, **kwargs)
+ files_timestamps = self._get_files_timestamps(save_directory)
+
+ special_tokens_map_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + SPECIAL_TOKENS_MAP_FILE
+ )
+ tokenizer_config_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + TOKENIZER_CONFIG_FILE
+ )
+
+ tokenizer_config = copy.deepcopy(self.init_kwargs)
+
+ # Let's save the init kwargs
+ target_keys = set(self.init_kwargs.keys())
+ # Let's save the special tokens map (only the strings)
+ target_keys.update(["model_max_length", "clean_up_tokenization_spaces"])
+
+ for k in target_keys:
+ if hasattr(self, k):
+ tokenizer_config[k] = getattr(self, k)
+
+ # Let's make sure we properly save the special tokens
+ tokenizer_config.update(self.special_tokens_map)
+ if "extra_special_tokens" not in tokenizer_config:
+ tokenizer_config["extra_special_tokens"] = self.extra_special_tokens
+ tokenizer_config.update(self.extra_special_tokens)
+
+ save_jinja_files = kwargs.get("save_jinja_files", True)
+ tokenizer_config, saved_raw_chat_template_files = self.save_chat_templates(
+ save_directory, tokenizer_config, filename_prefix, save_jinja_files
+ )
+
+ if len(self.init_inputs) > 0:
+ tokenizer_config["init_inputs"] = copy.deepcopy(self.init_inputs)
+ for file_id in self.vocab_files_names:
+ tokenizer_config.pop(file_id, None)
+
+ # no typefields, this way old fast and slow can load it
+ tokenizer_config = self.convert_added_tokens(tokenizer_config, add_type_field=True, save=True)
+
+ # Process added tokens separately: allows previous versions to ignore it!
+ added_tokens = {}
+ for key, value in self.added_tokens_decoder.items():
+ added_tokens[key] = value.__getstate__()
+ tokenizer_config["added_tokens_decoder"] = added_tokens
+
+ # Add tokenizer class to the tokenizer config to be able to reload it with from_pretrained
+ tokenizer_class = self.__class__.__name__
+ # Remove the Fast at the end if we can save the slow tokenizer
+ if tokenizer_class.endswith("Fast") and getattr(self, "can_save_slow_tokenizer", False):
+ tokenizer_class = tokenizer_class[:-4]
+ tokenizer_config["tokenizer_class"] = tokenizer_class
+ if getattr(self, "_auto_map", None) is not None:
+ tokenizer_config["auto_map"] = self._auto_map
+ if getattr(self, "_processor_class", None) is not None:
+ tokenizer_config["processor_class"] = self._processor_class
+
+ # If we have a custom model, we copy the file defining it in the folder and set the attributes so it can be
+ # loaded from the Hub.
+ if self._auto_class is not None:
+ custom_object_save(self, save_directory, config=tokenizer_config)
+
+ # remove private information
+ if "name_or_path" in tokenizer_config:
+ tokenizer_config.pop("name_or_path")
+ tokenizer_config.pop("special_tokens_map_file", None)
+ tokenizer_config.pop("tokenizer_file", None)
+ if "device_map" in tokenizer_config:
+ tokenizer_config.pop("device_map")
+
+ with open(tokenizer_config_file, "w", encoding="utf-8") as f:
+ out_str = json.dumps(tokenizer_config, indent=2, sort_keys=True, ensure_ascii=False) + "\n"
+ f.write(out_str)
+ logger.info(f"tokenizer config file saved in {tokenizer_config_file}")
+
+ # Sanitize AddedTokens in special_tokens_map
+
+ # kept for forward compatibility, will be removed in transoformers 5. Typefields are not saved for FC, special should not be save either
+ write_dict = self.convert_added_tokens(self.special_tokens_map_extended, save=True, add_type_field=False)
+ with open(special_tokens_map_file, "w", encoding="utf-8") as f:
+ out_str = json.dumps(write_dict, indent=2, sort_keys=True, ensure_ascii=False) + "\n"
+ f.write(out_str)
+ logger.info(f"Special tokens file saved in {special_tokens_map_file}")
+
+ file_names = (tokenizer_config_file, special_tokens_map_file, *saved_raw_chat_template_files)
+
+ save_files = self._save_pretrained(
+ save_directory=save_directory,
+ file_names=file_names,
+ legacy_format=legacy_format,
+ filename_prefix=filename_prefix,
+ )
+
+ if push_to_hub:
+ self._upload_modified_files(
+ save_directory,
+ repo_id,
+ files_timestamps,
+ commit_message=commit_message,
+ token=kwargs.get("token"),
+ )
+
+ return save_files
+
+ def _save_pretrained(
+ self,
+ save_directory: Union[str, os.PathLike],
+ file_names: tuple[str, ...],
+ legacy_format: Optional[bool] = None,
+ filename_prefix: Optional[str] = None,
+ ) -> tuple[str, ...]:
+ """
+ Save a tokenizer using the slow-tokenizer/legacy format: vocabulary + added tokens.
+
+ Fast tokenizers can also be saved in a unique JSON file containing {config + vocab + added-tokens} using the
+ specific [`~tokenization_utils_fast.PreTrainedTokenizerFast._save_pretrained`]
+ """
+ if legacy_format is False:
+ raise ValueError(
+ "Only fast tokenizers (instances of PreTrainedTokenizerFast) can be saved in non legacy format."
+ )
+
+ save_directory = str(save_directory)
+
+ added_tokens_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + ADDED_TOKENS_FILE
+ )
+ # the new get_added_vocab() also returns special tokens and tokens that have an index < vocab_size
+ added_vocab = {tok: index for tok, index in self.added_tokens_encoder.items() if index >= self.vocab_size}
+ if added_vocab:
+ with open(added_tokens_file, "w", encoding="utf-8") as f:
+ out_str = json.dumps(added_vocab, indent=2, sort_keys=True, ensure_ascii=False) + "\n"
+ f.write(out_str)
+ logger.info(f"added tokens file saved in {added_tokens_file}")
+
+ vocab_files = self.save_vocabulary(save_directory, filename_prefix=filename_prefix)
+
+ return file_names + vocab_files + (added_tokens_file,)
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str, ...]:
+ """
+ Save only the vocabulary of the tokenizer (vocabulary + added tokens).
+
+ This method won't save the configuration and special token mappings of the tokenizer. Use
+ [`~PreTrainedTokenizerFast._save_pretrained`] to save the whole state of the tokenizer.
+
+ Args:
+ save_directory (`str`):
+ The directory in which to save the vocabulary.
+ filename_prefix (`str`, *optional*):
+ An optional prefix to add to the named of the saved files.
+
+ Returns:
+ `tuple(str)`: Paths to the files saved.
+ """
+ raise NotImplementedError
+
+ def tokenize(self, text: str, pair: Optional[str] = None, add_special_tokens: bool = False, **kwargs) -> list[str]:
+ """
+ Converts a string into a sequence of tokens, replacing unknown tokens with the `unk_token`.
+
+ Args:
+ text (`str`):
+ The sequence to be encoded.
+ pair (`str`, *optional*):
+ A second sequence to be encoded with the first.
+ add_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not to add the special tokens associated with the corresponding model.
+ kwargs (additional keyword arguments, *optional*):
+ Will be passed to the underlying model specific encode method. See details in
+ [`~PreTrainedTokenizerBase.__call__`]
+
+ Returns:
+ `list[str]`: The list of tokens.
+ """
+ raise NotImplementedError
+
+ @add_end_docstrings(
+ ENCODE_KWARGS_DOCSTRING,
+ """
+ **kwargs: Passed along to the `.tokenize()` method.
+ """,
+ """
+ Returns:
+ `list[int]`, `torch.Tensor`, `tf.Tensor` or `np.ndarray`: The tokenized ids of the text.
+ """,
+ )
+ def encode(
+ self,
+ text: Union[TextInput, PreTokenizedInput, EncodedInput],
+ text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None,
+ add_special_tokens: bool = True,
+ padding: Union[bool, str, PaddingStrategy] = False,
+ truncation: Union[bool, str, TruncationStrategy, None] = None,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ padding_side: Optional[str] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ **kwargs,
+ ) -> list[int]:
+ """
+ Converts a string to a sequence of ids (integer), using the tokenizer and vocabulary.
+
+ Same as doing `self.convert_tokens_to_ids(self.tokenize(text))`.
+
+ Args:
+ text (`str`, `list[str]` or `list[int]`):
+ The first sequence to be encoded. This can be a string, a list of strings (tokenized string using the
+ `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids`
+ method).
+ text_pair (`str`, `list[str]` or `list[int]`, *optional*):
+ Optional second sequence to be encoded. This can be a string, a list of strings (tokenized string using
+ the `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids`
+ method).
+ """
+ encoded_inputs = self.encode_plus(
+ text,
+ text_pair=text_pair,
+ add_special_tokens=add_special_tokens,
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ stride=stride,
+ padding_side=padding_side,
+ return_tensors=return_tensors,
+ **kwargs,
+ )
+
+ return encoded_inputs["input_ids"]
+
+ def num_special_tokens_to_add(self, pair: bool = False) -> int:
+ raise NotImplementedError
+
+ def _get_padding_truncation_strategies(
+ self, padding=False, truncation=None, max_length=None, pad_to_multiple_of=None, verbose=True, **kwargs
+ ):
+ """
+ Find the correct padding/truncation strategy
+ """
+
+ # Backward compatibility for previous behavior, maybe we should deprecate it:
+ # If you only set max_length, it activates truncation for max_length
+ if max_length is not None and padding is False and truncation is None:
+ if verbose:
+ if not self.deprecation_warnings.get("Truncation-not-explicitly-activated", False):
+ logger.warning(
+ "Truncation was not explicitly activated but `max_length` is provided a specific value, please"
+ " use `truncation=True` to explicitly truncate examples to max length. Defaulting to"
+ " 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the"
+ " tokenizer you can select this strategy more precisely by providing a specific strategy to"
+ " `truncation`."
+ )
+ self.deprecation_warnings["Truncation-not-explicitly-activated"] = True
+ truncation = "longest_first"
+
+ # Get padding strategy
+ if padding is not False:
+ if padding is True:
+ if verbose:
+ if max_length is not None and (
+ truncation is None or truncation is False or truncation == "do_not_truncate"
+ ):
+ warnings.warn(
+ "`max_length` is ignored when `padding`=`True` and there is no truncation strategy. "
+ "To pad to max length, use `padding='max_length'`."
+ )
+ padding_strategy = PaddingStrategy.LONGEST # Default to pad to the longest sequence in the batch
+ elif not isinstance(padding, PaddingStrategy):
+ padding_strategy = PaddingStrategy(padding)
+ elif isinstance(padding, PaddingStrategy):
+ padding_strategy = padding
+ else:
+ padding_strategy = PaddingStrategy.DO_NOT_PAD
+
+ # Get truncation strategy
+ if truncation is not False and truncation is not None:
+ if truncation is True:
+ truncation_strategy = (
+ TruncationStrategy.LONGEST_FIRST
+ ) # Default to truncate the longest sequences in pairs of inputs
+ elif not isinstance(truncation, TruncationStrategy):
+ truncation_strategy = TruncationStrategy(truncation)
+ elif isinstance(truncation, TruncationStrategy):
+ truncation_strategy = truncation
+ else:
+ truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
+
+ # Set max length if needed
+ if max_length is None:
+ if padding_strategy == PaddingStrategy.MAX_LENGTH:
+ if self.model_max_length > LARGE_INTEGER:
+ if verbose:
+ if not self.deprecation_warnings.get("Asking-to-pad-to-max_length", False):
+ logger.warning(
+ "Asking to pad to max_length but no maximum length is provided and the model has no"
+ " predefined maximum length. Default to no padding."
+ )
+ self.deprecation_warnings["Asking-to-pad-to-max_length"] = True
+ padding_strategy = PaddingStrategy.DO_NOT_PAD
+ else:
+ max_length = self.model_max_length
+
+ if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE:
+ if self.model_max_length > LARGE_INTEGER:
+ if verbose:
+ if not self.deprecation_warnings.get("Asking-to-truncate-to-max_length", False):
+ logger.warning(
+ "Asking to truncate to max_length but no maximum length is provided and the model has"
+ " no predefined maximum length. Default to no truncation."
+ )
+ self.deprecation_warnings["Asking-to-truncate-to-max_length"] = True
+ truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
+ else:
+ max_length = self.model_max_length
+
+ # Test if we have a padding token
+ if padding_strategy != PaddingStrategy.DO_NOT_PAD and (self.pad_token is None or self.pad_token_id < 0):
+ raise ValueError(
+ "Asking to pad but the tokenizer does not have a padding token. "
+ "Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` "
+ "or add a new pad token via `tokenizer.add_special_tokens({'pad_token': '[PAD]'})`."
+ )
+
+ # Check that we will truncate to a multiple of pad_to_multiple_of if both are provided
+ if (
+ truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE
+ and padding_strategy != PaddingStrategy.DO_NOT_PAD
+ and pad_to_multiple_of is not None
+ and max_length is not None
+ and (max_length % pad_to_multiple_of != 0)
+ ):
+ raise ValueError(
+ "Truncation and padding are both activated but "
+ f"truncation length ({max_length}) is not a multiple of pad_to_multiple_of ({pad_to_multiple_of})."
+ )
+
+ return padding_strategy, truncation_strategy, max_length, kwargs
+
+ @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
+ def __call__(
+ self,
+ text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput], None] = None,
+ text_pair: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]] = None,
+ text_target: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput], None] = None,
+ text_pair_target: Optional[
+ Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]
+ ] = None,
+ add_special_tokens: bool = True,
+ padding: Union[bool, str, PaddingStrategy] = False,
+ truncation: Union[bool, str, TruncationStrategy, None] = None,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ is_split_into_words: bool = False,
+ pad_to_multiple_of: Optional[int] = None,
+ padding_side: Optional[str] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ **kwargs,
+ ) -> BatchEncoding:
+ """
+ Main method to tokenize and prepare for the model one or several sequence(s) or one or several pair(s) of
+ sequences.
+
+ Args:
+ text (`str`, `list[str]`, `list[list[str]]`, *optional*):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+ text_pair (`str`, `list[str]`, `list[list[str]]`, *optional*):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+ text_target (`str`, `list[str]`, `list[list[str]]`, *optional*):
+ The sequence or batch of sequences to be encoded as target texts. Each sequence can be a string or a
+ list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized),
+ you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+ text_pair_target (`str`, `list[str]`, `list[list[str]]`, *optional*):
+ The sequence or batch of sequences to be encoded as target texts. Each sequence can be a string or a
+ list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized),
+ you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+ """
+ # To avoid duplicating
+ all_kwargs = {
+ "add_special_tokens": add_special_tokens,
+ "padding": padding,
+ "truncation": truncation,
+ "max_length": max_length,
+ "stride": stride,
+ "is_split_into_words": is_split_into_words,
+ "pad_to_multiple_of": pad_to_multiple_of,
+ "padding_side": padding_side,
+ "return_tensors": return_tensors,
+ "return_token_type_ids": return_token_type_ids,
+ "return_attention_mask": return_attention_mask,
+ "return_overflowing_tokens": return_overflowing_tokens,
+ "return_special_tokens_mask": return_special_tokens_mask,
+ "return_offsets_mapping": return_offsets_mapping,
+ "return_length": return_length,
+ "split_special_tokens": kwargs.pop("split_special_tokens", self.split_special_tokens),
+ "verbose": verbose,
+ }
+
+ if return_tensors in ("tf", "jax"):
+ logger.warning_once(
+ "TensorFlow and JAX classes are deprecated and will be removed in Transformers v5. We "
+ "recommend migrating to PyTorch classes or pinning your version of Transformers."
+ )
+ all_kwargs.update(kwargs)
+ if text is None and text_target is None:
+ raise ValueError("You need to specify either `text` or `text_target`.")
+ if text is not None:
+ # The context manager will send the inputs as normal texts and not text_target, but we shouldn't change the
+ # input mode in this case.
+ if not self._in_target_context_manager:
+ self._switch_to_input_mode()
+ encodings = self._call_one(text=text, text_pair=text_pair, **all_kwargs)
+ if text_target is not None:
+ self._switch_to_target_mode()
+ target_encodings = self._call_one(text=text_target, text_pair=text_pair_target, **all_kwargs)
+ # Leave back tokenizer in input mode
+ self._switch_to_input_mode()
+
+ if text_target is None:
+ return encodings
+ elif text is None:
+ return target_encodings
+ else:
+ encodings["labels"] = target_encodings["input_ids"]
+ return encodings
+
+ def _call_one(
+ self,
+ text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]],
+ text_pair: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]] = None,
+ add_special_tokens: bool = True,
+ padding: Union[bool, str, PaddingStrategy] = False,
+ truncation: Union[bool, str, TruncationStrategy, None] = None,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ is_split_into_words: bool = False,
+ pad_to_multiple_of: Optional[int] = None,
+ padding_side: Optional[str] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ split_special_tokens: bool = False,
+ **kwargs,
+ ) -> BatchEncoding:
+ # Input type checking for clearer error
+ def _is_valid_text_input(t):
+ if isinstance(t, str):
+ # Strings are fine
+ return True
+ elif isinstance(t, (list, tuple)):
+ # List are fine as long as they are...
+ if len(t) == 0:
+ # ... empty
+ return True
+ elif isinstance(t[0], str):
+ # ... list of strings
+ return True
+ elif isinstance(t[0], (list, tuple)):
+ # ... list with an empty list or with a list of strings
+ return len(t[0]) == 0 or isinstance(t[0][0], str)
+ else:
+ return False
+ else:
+ return False
+
+ if not _is_valid_text_input(text):
+ raise ValueError(
+ "text input must be of type `str` (single example), `list[str]` (batch or single pretokenized example) "
+ "or `list[list[str]]` (batch of pretokenized examples)."
+ )
+
+ if text_pair is not None and not _is_valid_text_input(text_pair):
+ raise ValueError(
+ "text input must be of type `str` (single example), `list[str]` (batch or single pretokenized example) "
+ "or `list[list[str]]` (batch of pretokenized examples)."
+ )
+
+ if is_split_into_words:
+ is_batched = isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple))
+ else:
+ is_batched = isinstance(text, (list, tuple))
+
+ if is_batched:
+ if isinstance(text_pair, str):
+ raise TypeError(
+ "when tokenizing batches of text, `text_pair` must be a list or tuple with the same length as"
+ " `text`."
+ )
+ if text_pair is not None and len(text) != len(text_pair):
+ raise ValueError(
+ f"batch length of `text`: {len(text)} does not match batch length of `text_pair`:"
+ f" {len(text_pair)}."
+ )
+ batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text
+ return self.batch_encode_plus(
+ batch_text_or_text_pairs=batch_text_or_text_pairs,
+ add_special_tokens=add_special_tokens,
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ stride=stride,
+ is_split_into_words=is_split_into_words,
+ pad_to_multiple_of=pad_to_multiple_of,
+ padding_side=padding_side,
+ return_tensors=return_tensors,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=return_offsets_mapping,
+ return_length=return_length,
+ verbose=verbose,
+ split_special_tokens=split_special_tokens,
+ **kwargs,
+ )
+ else:
+ return self.encode_plus(
+ text=text,
+ text_pair=text_pair,
+ add_special_tokens=add_special_tokens,
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ stride=stride,
+ is_split_into_words=is_split_into_words,
+ pad_to_multiple_of=pad_to_multiple_of,
+ padding_side=padding_side,
+ return_tensors=return_tensors,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=return_offsets_mapping,
+ return_length=return_length,
+ verbose=verbose,
+ split_special_tokens=split_special_tokens,
+ **kwargs,
+ )
+
+ @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
+ def encode_plus(
+ self,
+ text: Union[TextInput, PreTokenizedInput, EncodedInput],
+ text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None,
+ add_special_tokens: bool = True,
+ padding: Union[bool, str, PaddingStrategy] = False,
+ truncation: Union[bool, str, TruncationStrategy, None] = None,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ is_split_into_words: bool = False,
+ pad_to_multiple_of: Optional[int] = None,
+ padding_side: Optional[str] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ **kwargs,
+ ) -> BatchEncoding:
+ """
+ Tokenize and prepare for the model a sequence or a pair of sequences.
+
+
+
+ This method is deprecated, `__call__` should be used instead.
+
+
+
+ Args:
+ text (`str`, `list[str]` or (for non-fast tokenizers) `list[int]`):
+ The first sequence to be encoded. This can be a string, a list of strings (tokenized string using the
+ `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids`
+ method).
+ text_pair (`str`, `list[str]` or `list[int]`, *optional*):
+ Optional second sequence to be encoded. This can be a string, a list of strings (tokenized string using
+ the `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids`
+ method).
+ """
+
+ padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ pad_to_multiple_of=pad_to_multiple_of,
+ verbose=verbose,
+ **kwargs,
+ )
+
+ return self._encode_plus(
+ text=text,
+ text_pair=text_pair,
+ add_special_tokens=add_special_tokens,
+ padding_strategy=padding_strategy,
+ truncation_strategy=truncation_strategy,
+ max_length=max_length,
+ stride=stride,
+ is_split_into_words=is_split_into_words,
+ pad_to_multiple_of=pad_to_multiple_of,
+ padding_side=padding_side,
+ return_tensors=return_tensors,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=return_offsets_mapping,
+ return_length=return_length,
+ verbose=verbose,
+ split_special_tokens=kwargs.pop("split_special_tokens", self.split_special_tokens),
+ **kwargs,
+ )
+
+ def _encode_plus(
+ self,
+ text: Union[TextInput, PreTokenizedInput, EncodedInput],
+ text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None,
+ add_special_tokens: bool = True,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ is_split_into_words: bool = False,
+ pad_to_multiple_of: Optional[int] = None,
+ padding_side: Optional[str] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ split_special_tokens: bool = False,
+ **kwargs,
+ ) -> BatchEncoding:
+ raise NotImplementedError
+
+ @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
+ def batch_encode_plus(
+ self,
+ batch_text_or_text_pairs: Union[
+ list[TextInput],
+ list[TextInputPair],
+ list[PreTokenizedInput],
+ list[PreTokenizedInputPair],
+ list[EncodedInput],
+ list[EncodedInputPair],
+ ],
+ add_special_tokens: bool = True,
+ padding: Union[bool, str, PaddingStrategy] = False,
+ truncation: Union[bool, str, TruncationStrategy, None] = None,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ is_split_into_words: bool = False,
+ pad_to_multiple_of: Optional[int] = None,
+ padding_side: Optional[str] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ split_special_tokens: bool = False,
+ **kwargs,
+ ) -> BatchEncoding:
+ """
+ Tokenize and prepare for the model a list of sequences or a list of pairs of sequences.
+
+
+
+ This method is deprecated, `__call__` should be used instead.
+
+
+
+ Args:
+ batch_text_or_text_pairs (`list[str]`, `list[tuple[str, str]]`, `list[list[str]]`, `list[tuple[list[str], list[str]]]`, and for not-fast tokenizers, also `list[list[int]]`, `list[tuple[list[int], list[int]]]`):
+ Batch of sequences or pair of sequences to be encoded. This can be a list of
+ string/string-sequences/int-sequences or a list of pair of string/string-sequences/int-sequence (see
+ details in `encode_plus`).
+ """
+
+ # Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
+ padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ pad_to_multiple_of=pad_to_multiple_of,
+ verbose=verbose,
+ **kwargs,
+ )
+
+ return self._batch_encode_plus(
+ batch_text_or_text_pairs=batch_text_or_text_pairs,
+ add_special_tokens=add_special_tokens,
+ padding_strategy=padding_strategy,
+ truncation_strategy=truncation_strategy,
+ max_length=max_length,
+ stride=stride,
+ is_split_into_words=is_split_into_words,
+ pad_to_multiple_of=pad_to_multiple_of,
+ padding_side=padding_side,
+ return_tensors=return_tensors,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=return_offsets_mapping,
+ return_length=return_length,
+ verbose=verbose,
+ split_special_tokens=split_special_tokens,
+ **kwargs,
+ )
+
+ def _batch_encode_plus(
+ self,
+ batch_text_or_text_pairs: Union[
+ list[TextInput],
+ list[TextInputPair],
+ list[PreTokenizedInput],
+ list[PreTokenizedInputPair],
+ list[EncodedInput],
+ list[EncodedInputPair],
+ ],
+ add_special_tokens: bool = True,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ is_split_into_words: bool = False,
+ pad_to_multiple_of: Optional[int] = None,
+ padding_side: Optional[str] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ split_special_tokens: bool = False,
+ **kwargs,
+ ) -> BatchEncoding:
+ raise NotImplementedError
+
+ def pad(
+ self,
+ encoded_inputs: Union[
+ BatchEncoding,
+ list[BatchEncoding],
+ dict[str, EncodedInput],
+ dict[str, list[EncodedInput]],
+ list[dict[str, EncodedInput]],
+ ],
+ padding: Union[bool, str, PaddingStrategy] = True,
+ max_length: Optional[int] = None,
+ pad_to_multiple_of: Optional[int] = None,
+ padding_side: Optional[str] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ verbose: bool = True,
+ ) -> BatchEncoding:
+ """
+ Pad a single encoded input or a batch of encoded inputs up to predefined length or to the max sequence length
+ in the batch.
+
+ Padding side (left/right) padding token ids are defined at the tokenizer level (with `self.padding_side`,
+ `self.pad_token_id` and `self.pad_token_type_id`).
+
+ Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the
+ text followed by a call to the `pad` method to get a padded encoding.
+
+
+
+ If the `encoded_inputs` passed are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the
+ result will use the same type unless you provide a different tensor type with `return_tensors`. In the case of
+ PyTorch tensors, you will lose the specific device of your tensors however.
+
+
+
+ Args:
+ encoded_inputs ([`BatchEncoding`], list of [`BatchEncoding`], `dict[str, list[int]]`, `dict[str, list[list[int]]` or `list[dict[str, list[int]]]`):
+ Tokenized inputs. Can represent one input ([`BatchEncoding`] or `dict[str, list[int]]`) or a batch of
+ tokenized inputs (list of [`BatchEncoding`], *dict[str, list[list[int]]]* or *list[dict[str,
+ list[int]]]*) so you can use this method during preprocessing as well as in a PyTorch Dataloader
+ collate function.
+
+ Instead of `list[int]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors), see
+ the note above for the return type.
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding
+ index) among:
+
+ - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
+ sequence if provided).
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+ acceptable input length for the model if that argument is not provided.
+ - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different
+ lengths).
+ max_length (`int`, *optional*):
+ Maximum length of the returned list and optionally padding length (see above).
+ pad_to_multiple_of (`int`, *optional*):
+ If set will pad the sequence to a multiple of the provided value.
+
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
+ `>= 7.5` (Volta).
+ padding_side (`str`, *optional*):
+ The side on which the model should have padding applied. Should be selected between ['right', 'left'].
+ Default value is picked from the class attribute of the same name.
+ return_attention_mask (`bool`, *optional*):
+ Whether to return the attention mask. If left to the default, will return the attention mask according
+ to the specific tokenizer's default, defined by the `return_outputs` attribute.
+
+ [What are attention masks?](../glossary#attention-mask)
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors instead of list of python integers. Acceptable values are:
+
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return Numpy `np.ndarray` objects.
+ verbose (`bool`, *optional*, defaults to `True`):
+ Whether or not to print more information and warnings.
+ """
+ if self.__class__.__name__.endswith("Fast"):
+ if not self.deprecation_warnings.get("Asking-to-pad-a-fast-tokenizer", False):
+ logger.warning_advice(
+ f"You're using a {self.__class__.__name__} tokenizer. Please note that with a fast tokenizer,"
+ " using the `__call__` method is faster than using a method to encode the text followed by a call"
+ " to the `pad` method to get a padded encoding."
+ )
+ self.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
+
+ # If we have a list of dicts, let's convert it in a dict of lists
+ # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
+ if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], Mapping):
+ encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0]}
+
+ # The model's main input name, usually `input_ids`, has been passed for padding
+ if self.model_input_names[0] not in encoded_inputs:
+ raise ValueError(
+ "You should supply an encoding or a list of encodings to this method "
+ f"that includes {self.model_input_names[0]}, but you provided {list(encoded_inputs.keys())}"
+ )
+
+ required_input = encoded_inputs[self.model_input_names[0]]
+
+ if required_input is None or (isinstance(required_input, Sized) and len(required_input) == 0):
+ if return_attention_mask:
+ encoded_inputs["attention_mask"] = []
+ return encoded_inputs
+
+ # If we have PyTorch/TF/NumPy tensors/arrays as inputs, we cast them as python objects
+ # and rebuild them afterwards if no return_tensors is specified
+ # Note that we lose the specific device the tensor may be on for PyTorch
+
+ first_element = required_input[0]
+ if isinstance(first_element, (list, tuple)):
+ # first_element might be an empty list/tuple in some edge cases so we grab the first non empty element.
+ for item in required_input:
+ if len(item) != 0:
+ first_element = item[0]
+ break
+ # At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
+ if not isinstance(first_element, (int, list, tuple)):
+ if is_tf_tensor(first_element):
+ return_tensors = "tf" if return_tensors is None else return_tensors
+ elif is_torch_tensor(first_element):
+ return_tensors = "pt" if return_tensors is None else return_tensors
+ elif isinstance(first_element, np.ndarray):
+ return_tensors = "np" if return_tensors is None else return_tensors
+ else:
+ raise ValueError(
+ f"type of {first_element} unknown: {type(first_element)}. "
+ "Should be one of a python, numpy, pytorch or tensorflow object."
+ )
+
+ for key, value in encoded_inputs.items():
+ encoded_inputs[key] = to_py_obj(value)
+
+ # Convert padding_strategy in PaddingStrategy
+ padding_strategy, _, max_length, _ = self._get_padding_truncation_strategies(
+ padding=padding, max_length=max_length, verbose=verbose
+ )
+
+ required_input = encoded_inputs[self.model_input_names[0]]
+ if required_input and not isinstance(required_input[0], (list, tuple)):
+ encoded_inputs = self._pad(
+ encoded_inputs,
+ max_length=max_length,
+ padding_strategy=padding_strategy,
+ pad_to_multiple_of=pad_to_multiple_of,
+ padding_side=padding_side,
+ return_attention_mask=return_attention_mask,
+ )
+ return BatchEncoding(encoded_inputs, tensor_type=return_tensors)
+
+ batch_size = len(required_input)
+ assert all(len(v) == batch_size for v in encoded_inputs.values()), (
+ "Some items in the output dictionary have a different batch size than others."
+ )
+
+ if padding_strategy == PaddingStrategy.LONGEST:
+ max_length = max(len(inputs) for inputs in required_input)
+ padding_strategy = PaddingStrategy.MAX_LENGTH
+
+ batch_outputs = {}
+ for i in range(batch_size):
+ inputs = {k: v[i] for k, v in encoded_inputs.items()}
+ outputs = self._pad(
+ inputs,
+ max_length=max_length,
+ padding_strategy=padding_strategy,
+ pad_to_multiple_of=pad_to_multiple_of,
+ padding_side=padding_side,
+ return_attention_mask=return_attention_mask,
+ )
+
+ for key, value in outputs.items():
+ if key not in batch_outputs:
+ batch_outputs[key] = []
+ batch_outputs[key].append(value)
+
+ return BatchEncoding(batch_outputs, tensor_type=return_tensors)
+
+ def create_token_type_ids_from_sequences(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+ ) -> list[int]:
+ """
+ Create the token type IDs corresponding to the sequences passed. [What are token type
+ IDs?](../glossary#token-type-ids)
+
+ Should be overridden in a subclass if the model has a special way of building those.
+
+ Args:
+ token_ids_0 (`list[int]`): The first tokenized sequence.
+ token_ids_1 (`list[int]`, *optional*): The second tokenized sequence.
+
+ Returns:
+ `list[int]`: The token type ids.
+ """
+ cls_len = int(getattr(self, "cls_token_id", None) is not None)
+ sep_len = int(getattr(self, "sep_token_id", None) is not None)
+
+ if token_ids_1 is None:
+ return [0] * (cls_len + len(token_ids_0) + sep_len)
+
+ return [0] * (cls_len + len(token_ids_0) + sep_len) + [1] * (len(token_ids_1) + sep_len)
+
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+ ) -> list[int]:
+ """
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
+ adding special tokens.
+
+ This implementation does not add special tokens and this method should be overridden in a subclass.
+
+ Args:
+ token_ids_0 (`list[int]`): The first tokenized sequence.
+ token_ids_1 (`list[int]`, *optional*): The second tokenized sequence.
+
+ Returns:
+ `list[int]`: The model input with special tokens.
+ """
+ if token_ids_1 is None:
+ return token_ids_0
+ return token_ids_0 + token_ids_1
+
+ @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
+ def prepare_for_model(
+ self,
+ ids: list[int],
+ pair_ids: Optional[list[int]] = None,
+ add_special_tokens: bool = True,
+ padding: Union[bool, str, PaddingStrategy] = False,
+ truncation: Union[bool, str, TruncationStrategy, None] = None,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ padding_side: Optional[str] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ prepend_batch_axis: bool = False,
+ **kwargs,
+ ) -> BatchEncoding:
+ """
+ Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. It
+ adds special tokens, truncates sequences if overflowing while taking into account the special tokens and
+ manages a moving window (with user defined stride) for overflowing tokens. Please Note, for *pair_ids*
+ different than `None` and *truncation_strategy = longest_first* or `True`, it is not possible to return
+ overflowing tokens. Such a combination of arguments will raise an error.
+
+ Args:
+ ids (`list[int]`):
+ Tokenized input ids of the first sequence. Can be obtained from a string by chaining the `tokenize` and
+ `convert_tokens_to_ids` methods.
+ pair_ids (`list[int]`, *optional*):
+ Tokenized input ids of the second sequence. Can be obtained from a string by chaining the `tokenize`
+ and `convert_tokens_to_ids` methods.
+ """
+
+ # Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
+ padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ pad_to_multiple_of=pad_to_multiple_of,
+ verbose=verbose,
+ **kwargs,
+ )
+
+ pair = pair_ids is not None
+ len_ids = len(ids)
+ len_pair_ids = len(pair_ids) if pair else 0
+
+ if return_token_type_ids and not add_special_tokens:
+ raise ValueError(
+ "Asking to return token_type_ids while setting add_special_tokens to False "
+ "results in an undefined behavior. Please set add_special_tokens to True or "
+ "set return_token_type_ids to None."
+ )
+
+ if (
+ return_overflowing_tokens
+ and truncation_strategy == TruncationStrategy.LONGEST_FIRST
+ and pair_ids is not None
+ ):
+ raise ValueError(
+ "Not possible to return overflowing tokens for pair of sequences with the "
+ "`longest_first`. Please select another truncation strategy than `longest_first`, "
+ "for instance `only_second` or `only_first`."
+ )
+
+ # Load from model defaults
+ if return_token_type_ids is None:
+ return_token_type_ids = "token_type_ids" in self.model_input_names
+ if return_attention_mask is None:
+ return_attention_mask = "attention_mask" in self.model_input_names
+
+ encoded_inputs = {}
+
+ # Compute the total size of the returned encodings
+ total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0)
+
+ # Truncation: Handle max sequence length
+ overflowing_tokens = []
+ if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length:
+ ids, pair_ids, overflowing_tokens = self.truncate_sequences(
+ ids,
+ pair_ids=pair_ids,
+ num_tokens_to_remove=total_len - max_length,
+ truncation_strategy=truncation_strategy,
+ stride=stride,
+ )
+
+ if return_overflowing_tokens:
+ encoded_inputs["overflowing_tokens"] = overflowing_tokens
+ encoded_inputs["num_truncated_tokens"] = total_len - max_length
+
+ # Add special tokens
+ if add_special_tokens:
+ sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
+ token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids)
+ else:
+ sequence = ids + pair_ids if pair else ids
+ token_type_ids = [0] * len(ids) + ([0] * len(pair_ids) if pair else [])
+
+ # Build output dictionary
+ encoded_inputs["input_ids"] = sequence
+ if return_token_type_ids:
+ encoded_inputs["token_type_ids"] = token_type_ids
+ if return_special_tokens_mask:
+ if add_special_tokens:
+ encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids)
+ else:
+ encoded_inputs["special_tokens_mask"] = [0] * len(sequence)
+
+ # Check lengths
+ self._eventual_warn_about_too_long_sequence(encoded_inputs["input_ids"], max_length, verbose)
+
+ # Padding
+ if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask:
+ encoded_inputs = self.pad(
+ encoded_inputs,
+ max_length=max_length,
+ padding=padding_strategy.value,
+ pad_to_multiple_of=pad_to_multiple_of,
+ padding_side=padding_side,
+ return_attention_mask=return_attention_mask,
+ )
+
+ if return_length:
+ encoded_inputs["length"] = len(encoded_inputs["input_ids"])
+
+ batch_outputs = BatchEncoding(
+ encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis
+ )
+
+ return batch_outputs
+
+ def truncate_sequences(
+ self,
+ ids: list[int],
+ pair_ids: Optional[list[int]] = None,
+ num_tokens_to_remove: int = 0,
+ truncation_strategy: Union[str, TruncationStrategy] = "longest_first",
+ stride: int = 0,
+ ) -> tuple[list[int], list[int], list[int]]:
+ """
+ Truncates a sequence pair in-place following the strategy.
+
+ Args:
+ ids (`list[int]`):
+ Tokenized input ids of the first sequence. Can be obtained from a string by chaining the `tokenize` and
+ `convert_tokens_to_ids` methods.
+ pair_ids (`list[int]`, *optional*):
+ Tokenized input ids of the second sequence. Can be obtained from a string by chaining the `tokenize`
+ and `convert_tokens_to_ids` methods.
+ num_tokens_to_remove (`int`, *optional*, defaults to 0):
+ Number of tokens to remove using the truncation strategy.
+ truncation_strategy (`str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `'longest_first'`):
+ The strategy to follow for truncation. Can be:
+
+ - `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or to the
+ maximum acceptable input length for the model if that argument is not provided. This will truncate
+ token by token, removing a token from the longest sequence in the pair if a pair of sequences (or a
+ batch of pairs) is provided.
+ - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the
+ maximum acceptable input length for the model if that argument is not provided. This will only
+ truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+ - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the
+ maximum acceptable input length for the model if that argument is not provided. This will only
+ truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+ - `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths greater
+ than the model maximum admissible input size).
+ stride (`int`, *optional*, defaults to 0):
+ If set to a positive number, the overflowing tokens returned will contain some tokens from the main
+ sequence returned. The value of this argument defines the number of additional tokens.
+
+ Returns:
+ `tuple[list[int], list[int], list[int]]`: The truncated `ids`, the truncated `pair_ids` and the list of
+ overflowing tokens. Note: The *longest_first* strategy returns empty list of overflowing tokens if a pair
+ of sequences (or a batch of pairs) is provided.
+ """
+ if num_tokens_to_remove <= 0:
+ return ids, pair_ids, []
+
+ if not isinstance(truncation_strategy, TruncationStrategy):
+ truncation_strategy = TruncationStrategy(truncation_strategy)
+
+ overflowing_tokens = []
+ if truncation_strategy == TruncationStrategy.ONLY_FIRST or (
+ truncation_strategy == TruncationStrategy.LONGEST_FIRST and pair_ids is None
+ ):
+ if len(ids) > num_tokens_to_remove:
+ window_len = min(len(ids), stride + num_tokens_to_remove)
+ if self.truncation_side == "left":
+ overflowing_tokens = ids[:window_len]
+ ids = ids[num_tokens_to_remove:]
+ elif self.truncation_side == "right":
+ overflowing_tokens = ids[-window_len:]
+ ids = ids[:-num_tokens_to_remove]
+ else:
+ raise ValueError(f"invalid truncation strategy: {self.truncation_side}, use 'left' or 'right'.")
+
+ else:
+ error_msg = (
+ f"We need to remove {num_tokens_to_remove} to truncate the input "
+ f"but the first sequence has a length {len(ids)}. "
+ )
+ if truncation_strategy == TruncationStrategy.ONLY_FIRST:
+ error_msg = (
+ error_msg + "Please select another truncation strategy than "
+ f"{truncation_strategy}, for instance 'longest_first' or 'only_second'."
+ )
+ logger.error(error_msg)
+ elif truncation_strategy == TruncationStrategy.LONGEST_FIRST:
+ logger.warning(
+ "Be aware, overflowing tokens are not returned for the setting you have chosen,"
+ f" i.e. sequence pairs with the '{TruncationStrategy.LONGEST_FIRST.value}' "
+ "truncation strategy. So the returned list will always be empty even if some "
+ "tokens have been removed."
+ )
+ len_pair_ids = len(pair_ids) if pair_ids is not None else 0
+ len_ids = len(ids)
+ first_remove = min(abs(len_pair_ids - len_ids), num_tokens_to_remove)
+ second_remove = num_tokens_to_remove - first_remove
+ if len_ids > len_pair_ids:
+ ids_to_move = first_remove + second_remove // 2
+ pair_ids_to_move = second_remove - second_remove // 2
+ else:
+ ids_to_move = second_remove // 2
+ pair_ids_to_move = first_remove + second_remove - (second_remove // 2)
+
+ if self.truncation_side == "right":
+ ids = ids[:-ids_to_move] if ids_to_move > 0 else ids
+ pair_ids = pair_ids[:-pair_ids_to_move] if pair_ids is not None and pair_ids_to_move > 0 else pair_ids
+ elif self.truncation_side == "left":
+ ids = ids[ids_to_move:]
+ pair_ids = pair_ids[pair_ids_to_move:] if pair_ids is not None else None
+ else:
+ raise ValueError(f"invalid truncation strategy:{self.truncation_side}")
+
+ elif truncation_strategy == TruncationStrategy.ONLY_SECOND and pair_ids is not None:
+ if len(pair_ids) > num_tokens_to_remove:
+ window_len = min(len(pair_ids), stride + num_tokens_to_remove)
+ if self.truncation_side == "right":
+ overflowing_tokens = pair_ids[-window_len:]
+ pair_ids = pair_ids[:-num_tokens_to_remove]
+ elif self.truncation_side == "left":
+ overflowing_tokens = pair_ids[:window_len]
+ pair_ids = pair_ids[num_tokens_to_remove:]
+ else:
+ raise ValueError(f"invalid truncation strategy:{self.truncation_side}")
+ else:
+ logger.error(
+ f"We need to remove {num_tokens_to_remove} to truncate the input "
+ f"but the second sequence has a length {len(pair_ids)}. "
+ f"Please select another truncation strategy than {truncation_strategy}, "
+ "for instance 'longest_first' or 'only_first'."
+ )
+
+ return (ids, pair_ids, overflowing_tokens)
+
+ def _pad(
+ self,
+ encoded_inputs: Union[dict[str, EncodedInput], BatchEncoding],
+ max_length: Optional[int] = None,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ pad_to_multiple_of: Optional[int] = None,
+ padding_side: Optional[str] = None,
+ return_attention_mask: Optional[bool] = None,
+ ) -> dict:
+ """
+ Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
+
+ Args:
+ encoded_inputs:
+ Dictionary of tokenized inputs (`list[int]`) or batch of tokenized inputs (`list[list[int]]`).
+ max_length: maximum length of the returned list and optionally padding length (see below).
+ Will truncate by taking into account the special tokens.
+ padding_strategy: PaddingStrategy to use for padding.
+
+ - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
+ - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
+ - PaddingStrategy.DO_NOT_PAD: Do not pad
+ The tokenizer padding sides are defined in `padding_side` argument:
+
+ - 'left': pads on the left of the sequences
+ - 'right': pads on the right of the sequences
+ pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
+ This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
+ `>= 7.5` (Volta).
+ padding_side:
+ The side on which the model should have padding applied. Should be selected between ['right', 'left'].
+ Default value is picked from the class attribute of the same name.
+ return_attention_mask:
+ (optional) Set to False to avoid returning attention mask (default: set to model specifics)
+ """
+ # Load from model defaults
+ if return_attention_mask is None:
+ return_attention_mask = "attention_mask" in self.model_input_names
+
+ required_input = encoded_inputs[self.model_input_names[0]]
+
+ if padding_strategy == PaddingStrategy.LONGEST:
+ max_length = len(required_input)
+
+ if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
+
+ needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
+
+ # Initialize attention mask if not present.
+ if return_attention_mask and "attention_mask" not in encoded_inputs:
+ encoded_inputs["attention_mask"] = [1] * len(required_input)
+
+ if needs_to_be_padded:
+ difference = max_length - len(required_input)
+ padding_side = padding_side if padding_side is not None else self.padding_side
+
+ if padding_side == "right":
+ if return_attention_mask:
+ encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference
+ if "token_type_ids" in encoded_inputs:
+ encoded_inputs["token_type_ids"] = (
+ encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference
+ )
+ if "special_tokens_mask" in encoded_inputs:
+ encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
+ encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference
+ elif padding_side == "left":
+ if return_attention_mask:
+ encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
+ if "token_type_ids" in encoded_inputs:
+ encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
+ "token_type_ids"
+ ]
+ if "special_tokens_mask" in encoded_inputs:
+ encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
+ encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
+ else:
+ raise ValueError(f"Invalid padding strategy:{padding_side}")
+
+ return encoded_inputs
+
+ def convert_tokens_to_string(self, tokens: list[str]) -> str:
+ """
+ Converts a sequence of tokens in a single string. The most simple way to do it is `" ".join(tokens)` but we
+ often want to remove sub-word tokenization artifacts at the same time.
+
+ Args:
+ tokens (`list[str]`): The token to join in a string.
+
+ Returns:
+ `str`: The joined tokens.
+ """
+ raise NotImplementedError
+
+ def batch_decode(
+ self,
+ sequences: Union[list[int], list[list[int]], "np.ndarray", "torch.Tensor", "tf.Tensor"],
+ skip_special_tokens: bool = False,
+ clean_up_tokenization_spaces: Optional[bool] = None,
+ **kwargs,
+ ) -> list[str]:
+ """
+ Convert a list of lists of token ids into a list of strings by calling decode.
+
+ Args:
+ sequences (`Union[list[int], list[list[int]], np.ndarray, torch.Tensor, tf.Tensor]`):
+ List of tokenized input ids. Can be obtained using the `__call__` method.
+ skip_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not to remove special tokens in the decoding.
+ clean_up_tokenization_spaces (`bool`, *optional*):
+ Whether or not to clean up the tokenization spaces. If `None`, will default to
+ `self.clean_up_tokenization_spaces`.
+ kwargs (additional keyword arguments, *optional*):
+ Will be passed to the underlying model specific decode method.
+
+ Returns:
+ `list[str]`: The list of decoded sentences.
+ """
+ return [
+ self.decode(
+ seq,
+ skip_special_tokens=skip_special_tokens,
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+ **kwargs,
+ )
+ for seq in sequences
+ ]
+
+ def decode(
+ self,
+ token_ids: Union[int, list[int], np.ndarray, "torch.Tensor"],
+ skip_special_tokens: bool = False,
+ clean_up_tokenization_spaces: Optional[bool] = None,
+ **kwargs,
+ ) -> str:
+ """
+ Converts a sequence of ids in a string, using the tokenizer and vocabulary with options to remove special
+ tokens and clean up tokenization spaces.
+
+ Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`.
+
+ Args:
+ token_ids (`Union[int, list[int], np.ndarray, torch.Tensor, tf.Tensor]`):
+ List of tokenized input ids. Can be obtained using the `__call__` method.
+ skip_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not to remove special tokens in the decoding.
+ clean_up_tokenization_spaces (`bool`, *optional*):
+ Whether or not to clean up the tokenization spaces. If `None`, will default to
+ `self.clean_up_tokenization_spaces`.
+ kwargs (additional keyword arguments, *optional*):
+ Will be passed to the underlying model specific decode method.
+
+ Returns:
+ `str`: The decoded sentence.
+ """
+ # Convert inputs to python lists
+ token_ids = to_py_obj(token_ids)
+
+ return self._decode(
+ token_ids=token_ids,
+ skip_special_tokens=skip_special_tokens,
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+ **kwargs,
+ )
+
+ def _decode(
+ self,
+ token_ids: Union[int, list[int]],
+ skip_special_tokens: bool = False,
+ clean_up_tokenization_spaces: Optional[bool] = None,
+ **kwargs,
+ ) -> str:
+ raise NotImplementedError
+
+ def get_special_tokens_mask(
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
+ ) -> list[int]:
+ """
+ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
+ special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods.
+
+ Args:
+ token_ids_0 (`list[int]`):
+ List of ids of the first sequence.
+ token_ids_1 (`list[int]`, *optional*):
+ List of ids of the second sequence.
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not the token list is already formatted with special tokens for the model.
+
+ Returns:
+ A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+ """
+ assert already_has_special_tokens and token_ids_1 is None, (
+ "You cannot use ``already_has_special_tokens=False`` with this tokenizer. "
+ "Please use a slow (full python) tokenizer to activate this argument. "
+ "Or set `return_special_tokens_mask=True` when calling the encoding method "
+ "to get the special tokens mask in any tokenizer. "
+ )
+
+ all_special_ids = self.all_special_ids # cache the property
+
+ special_tokens_mask = [1 if token in all_special_ids else 0 for token in token_ids_0]
+
+ return special_tokens_mask
+
+ @staticmethod
+ def clean_up_tokenization(out_string: str) -> str:
+ """
+ Clean up a list of simple English tokenization artifacts like spaces before punctuations and abbreviated forms.
+
+ Args:
+ out_string (`str`): The text to clean up.
+
+ Returns:
+ `str`: The cleaned-up string.
+ """
+ out_string = (
+ out_string.replace(" .", ".")
+ .replace(" ?", "?")
+ .replace(" !", "!")
+ .replace(" ,", ",")
+ .replace(" ' ", "'")
+ .replace(" n't", "n't")
+ .replace(" 'm", "'m")
+ .replace(" 's", "'s")
+ .replace(" 've", "'ve")
+ .replace(" 're", "'re")
+ )
+ return out_string
+
+ def _eventual_warn_about_too_long_sequence(self, ids: list[int], max_length: Optional[int], verbose: bool):
+ """
+ Depending on the input and internal state we might trigger a warning about a sequence that is too long for its
+ corresponding model
+
+ Args:
+ ids (`list[str]`): The ids produced by the tokenization
+ max_length (`int`, *optional*): The max_length desired (does not trigger a warning if it is set)
+ verbose (`bool`): Whether or not to print more information and warnings.
+
+ """
+ if max_length is None and len(ids) > self.model_max_length and verbose and self.model_max_length != 0:
+ if not self.deprecation_warnings.get("sequence-length-is-longer-than-the-specified-maximum", False):
+ logger.warning(
+ "Token indices sequence length is longer than the specified maximum sequence length "
+ f"for this model ({len(ids)} > {self.model_max_length}). Running this sequence through the model "
+ "will result in indexing errors"
+ )
+ self.deprecation_warnings["sequence-length-is-longer-than-the-specified-maximum"] = True
+
+ def _switch_to_input_mode(self):
+ """
+ Private method to put the tokenizer in input mode (when it has different modes for input/outputs)
+ """
+ pass
+
+ def _switch_to_target_mode(self):
+ """
+ Private method to put the tokenizer in target mode (when it has different modes for input/outputs)
+ """
+ pass
+
+ @contextmanager
+ def as_target_tokenizer(self):
+ """
+ Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
+ sequence-to-sequence models that need a slightly different processing for the labels.
+ """
+ warnings.warn(
+ "`as_target_tokenizer` is deprecated and will be removed in v5 of Transformers. You can tokenize your "
+ "labels by using the argument `text_target` of the regular `__call__` method (either in the same call as "
+ "your input texts if you use the same keyword arguments, or in a separate call."
+ )
+ self._switch_to_target_mode()
+ self._in_target_context_manager = True
+ yield
+ self._in_target_context_manager = False
+ self._switch_to_input_mode()
+
+ @classmethod
+ def register_for_auto_class(cls, auto_class="AutoTokenizer"):
+ """
+ Register this class with a given auto class. This should only be used for custom tokenizers as the ones in the
+ library are already mapped with `AutoTokenizer`.
+
+
+
+ Args:
+ auto_class (`str` or `type`, *optional*, defaults to `"AutoTokenizer"`):
+ The auto class to register this new tokenizer with.
+ """
+ if not isinstance(auto_class, str):
+ auto_class = auto_class.__name__
+
+ import transformers.models.auto as auto_module
+
+ if not hasattr(auto_module, auto_class):
+ raise ValueError(f"{auto_class} is not a valid auto class.")
+
+ cls._auto_class = auto_class
+
+ def prepare_seq2seq_batch(
+ self,
+ src_texts: list[str],
+ tgt_texts: Optional[list[str]] = None,
+ max_length: Optional[int] = None,
+ max_target_length: Optional[int] = None,
+ padding: str = "longest",
+ return_tensors: Optional[str] = None,
+ truncation: bool = True,
+ **kwargs,
+ ) -> BatchEncoding:
+ """
+ Prepare model inputs for translation. For best performance, translate one sentence at a time.
+
+ Arguments:
+ src_texts (`list[str]`):
+ List of documents to summarize or source language texts.
+ tgt_texts (`list`, *optional*):
+ List of summaries or target language texts.
+ max_length (`int`, *optional*):
+ Controls the maximum length for encoder inputs (documents to summarize or source language texts) If
+ left unset or set to `None`, this will use the predefined model maximum length if a maximum length is
+ required by one of the truncation/padding parameters. If the model has no specific maximum input length
+ (like XLNet) truncation/padding to a maximum length will be deactivated.
+ max_target_length (`int`, *optional*):
+ Controls the maximum length of decoder inputs (target language texts or summaries) If left unset or set
+ to `None`, this will use the max_length value.
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
+ Activates and controls padding. Accepts the following values:
+
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+ sequence if provided).
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+ acceptable input length for the model if that argument is not provided.
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
+ lengths).
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors instead of list of python integers. Acceptable values are:
+
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return Numpy `np.ndarray` objects.
+ truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `True`):
+ Activates and controls truncation. Accepts the following values:
+
+ - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or
+ to the maximum acceptable input length for the model if that argument is not provided. This will
+ truncate token by token, removing a token from the longest sequence in the pair if a pair of
+ sequences (or a batch of pairs) is provided.
+ - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the
+ maximum acceptable input length for the model if that argument is not provided. This will only
+ truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+ - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the
+ maximum acceptable input length for the model if that argument is not provided. This will only
+ truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+ - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths
+ greater than the model maximum admissible input size).
+ **kwargs:
+ Additional keyword arguments passed along to `self.__call__`.
+
+ Return:
+ [`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
+
+ - **input_ids** -- List of token ids to be fed to the encoder.
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
+ - **labels** -- List of token ids for tgt_texts.
+
+ The full set of keys `[input_ids, attention_mask, labels]`, will only be returned if tgt_texts is passed.
+ Otherwise, input_ids, attention_mask will be the only keys.
+ """
+ # docstyle-ignore
+ formatted_warning = """
+`prepare_seq2seq_batch` is deprecated and will be removed in version 5 of HuggingFace Transformers. Use the regular
+`__call__` method to prepare your inputs and targets.
+
+Here is a short example:
+
+model_inputs = tokenizer(src_texts, text_target=tgt_texts, ...)
+
+If you either need to use different keyword arguments for the source and target texts, you should do two calls like
+this:
+
+model_inputs = tokenizer(src_texts, ...)
+labels = tokenizer(text_target=tgt_texts, ...)
+model_inputs["labels"] = labels["input_ids"]
+
+See the documentation of your specific tokenizer for more details on the specific arguments to the tokenizer of choice.
+For a more complete example, see the implementation of `prepare_seq2seq_batch`.
+"""
+ warnings.warn(formatted_warning, FutureWarning)
+ # mBART-specific kwargs that should be ignored by other models.
+ kwargs.pop("src_lang", None)
+ kwargs.pop("tgt_lang", None)
+ if max_length is None:
+ max_length = self.model_max_length
+ model_inputs = self(
+ src_texts,
+ add_special_tokens=True,
+ return_tensors=return_tensors,
+ max_length=max_length,
+ padding=padding,
+ truncation=truncation,
+ **kwargs,
+ )
+ if tgt_texts is None:
+ return model_inputs
+ # Process tgt_texts
+ if max_target_length is None:
+ max_target_length = max_length
+ with self.as_target_tokenizer():
+ labels = self(
+ tgt_texts,
+ add_special_tokens=True,
+ return_tensors=return_tensors,
+ padding=padding,
+ max_length=max_target_length,
+ truncation=truncation,
+ **kwargs,
+ )
+ model_inputs["labels"] = labels["input_ids"]
+ return model_inputs
+
+
+def get_fast_tokenizer_file(tokenization_files: list[str]) -> str:
+ """
+ Get the tokenization file to use for this version of transformers.
+
+ Args:
+ tokenization_files (`list[str]`): The list of available configuration files.
+
+ Returns:
+ `str`: The tokenization file to use.
+ """
+ tokenizer_files_map = {}
+ for file_name in tokenization_files:
+ search = _re_tokenizer_file.search(file_name)
+ if search is not None:
+ v = search.groups()[0]
+ tokenizer_files_map[v] = file_name
+ available_versions = sorted(tokenizer_files_map.keys())
+
+ # Defaults to FULL_TOKENIZER_FILE and then try to look at some newer versions.
+ tokenizer_file = FULL_TOKENIZER_FILE
+ transformers_version = version.parse(__version__)
+ for v in available_versions:
+ if version.parse(v) <= transformers_version:
+ tokenizer_file = tokenizer_files_map[v]
+ else:
+ # No point going further since the versions are sorted.
+ break
+
+ return tokenizer_file
+
+
+# To update the docstring, we need to copy the method, otherwise we change the original docstring.
+PreTrainedTokenizerBase.push_to_hub = copy_func(PreTrainedTokenizerBase.push_to_hub)
+if PreTrainedTokenizerBase.push_to_hub.__doc__ is not None:
+ PreTrainedTokenizerBase.push_to_hub.__doc__ = PreTrainedTokenizerBase.push_to_hub.__doc__.format(
+ object="tokenizer", object_class="AutoTokenizer", object_files="tokenizer files"
+ )
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/tokenization_utils_fast.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/tokenization_utils_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe4873d61b378410c90db80b4b19094f3c6ab292
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/tokenization_utils_fast.py
@@ -0,0 +1,922 @@
+# Copyright 2020 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Tokenization classes for fast tokenizers (provided by HuggingFace's tokenizers library). For slow (python) tokenizers
+see tokenization_utils.py
+"""
+
+import copy
+import json
+import os
+from collections import defaultdict
+from collections.abc import Iterable
+from typing import Any, Optional, Union
+
+import tokenizers.pre_tokenizers as pre_tokenizers_fast
+from tokenizers import Encoding as EncodingFast
+from tokenizers import Tokenizer as TokenizerFast
+from tokenizers.decoders import Decoder as DecoderFast
+from tokenizers.trainers import BpeTrainer, UnigramTrainer, WordLevelTrainer, WordPieceTrainer
+
+from .convert_slow_tokenizer import convert_slow_tokenizer
+from .integrations.ggml import convert_gguf_tokenizer
+from .modeling_gguf_pytorch_utils import load_gguf_checkpoint
+from .tokenization_utils import PreTrainedTokenizer
+from .tokenization_utils_base import (
+ INIT_TOKENIZER_DOCSTRING,
+ AddedToken,
+ BatchEncoding,
+ PreTokenizedInput,
+ PreTokenizedInputPair,
+ PreTrainedTokenizerBase,
+ SpecialTokensMixin,
+ TextInput,
+ TextInputPair,
+ TruncationStrategy,
+)
+from .utils import PaddingStrategy, add_end_docstrings, logging
+
+
+logger = logging.get_logger(__name__)
+
+# Fast tokenizers (provided by HuggingFace tokenizer's library) can be saved in a single file
+TOKENIZER_FILE = "tokenizer.json"
+SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json"
+TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
+TIKTOKEN_VOCAB_FILE = "tokenizer.model"
+
+# Slow tokenizers have an additional added tokens files
+ADDED_TOKENS_FILE = "added_tokens.json"
+
+INIT_TOKENIZER_DOCSTRING += """
+ tokenizer_object ([`tokenizers.Tokenizer`]):
+ A [`tokenizers.Tokenizer`] object from 🤗 tokenizers to instantiate from. See [Using tokenizers from 🤗
+ tokenizers](../fast_tokenizers) for more information.
+ tokenizer_file ([`str`]):
+ A path to a local JSON file representing a previously serialized [`tokenizers.Tokenizer`] object from 🤗
+ tokenizers.
+"""
+
+MODEL_TO_TRAINER_MAPPING = {
+ "BPE": BpeTrainer,
+ "Unigram": UnigramTrainer,
+ "WordLevel": WordLevelTrainer,
+ "WordPiece": WordPieceTrainer,
+}
+
+VOCAB_FILES_NAMES = {"tokenizer_file": TOKENIZER_FILE, "vocab_file": TIKTOKEN_VOCAB_FILE}
+
+
+@add_end_docstrings(INIT_TOKENIZER_DOCSTRING)
+class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
+ """
+ Base class for all fast tokenizers (wrapping HuggingFace tokenizers library).
+
+ Inherits from [`~tokenization_utils_base.PreTrainedTokenizerBase`].
+
+ Handles all the shared methods for tokenization and special tokens, as well as methods for
+ downloading/caching/loading pretrained tokenizers, as well as adding tokens to the vocabulary.
+
+ This class also contains the added tokens in a unified way on top of all tokenizers so we don't have to handle the
+ specific vocabulary augmentation methods of the various underlying dictionary structures (BPE, sentencepiece...).
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ slow_tokenizer_class: Optional[type[PreTrainedTokenizer]] = None
+
+ def __init__(self, *args, **kwargs):
+ tokenizer_object = kwargs.pop("tokenizer_object", None)
+ slow_tokenizer = kwargs.pop("__slow_tokenizer", None)
+ gguf_file = kwargs.pop("gguf_file", None)
+ fast_tokenizer_file = kwargs.pop("tokenizer_file", None)
+ from_slow = kwargs.pop("from_slow", False)
+ added_tokens_decoder = kwargs.pop("added_tokens_decoder", {})
+ self.add_prefix_space = kwargs.get("add_prefix_space", False)
+
+ if from_slow and slow_tokenizer is None and self.slow_tokenizer_class is None:
+ raise ValueError(
+ "Cannot instantiate this tokenizer from a slow version. If it's based on sentencepiece, make sure you "
+ "have sentencepiece installed."
+ )
+
+ if tokenizer_object is not None:
+ fast_tokenizer = copy.deepcopy(tokenizer_object)
+ elif fast_tokenizer_file is not None and not from_slow:
+ # We have a serialization from tokenizers which let us directly build the backend
+ fast_tokenizer = TokenizerFast.from_file(fast_tokenizer_file)
+ elif slow_tokenizer:
+ # We need to convert a slow tokenizer to build the backend
+ fast_tokenizer = convert_slow_tokenizer(slow_tokenizer)
+ elif gguf_file is not None:
+ # We need to convert a slow tokenizer to build the backend
+ gguf_param = load_gguf_checkpoint(kwargs.get("vocab_file"))
+ architecture = gguf_param["config"]["model_type"]
+ tokenizer_dict = gguf_param["tokenizer"]
+ tokenizer_config = gguf_param["tokenizer_config"]
+ fast_tokenizer, additional_kwargs = convert_gguf_tokenizer(architecture, tokenizer_dict)
+ kwargs.update(tokenizer_config)
+ if len(additional_kwargs) > 0:
+ kwargs.update(additional_kwargs)
+ elif self.slow_tokenizer_class is not None and slow_tokenizer is not False:
+ # We need to create and convert a slow tokenizer to build the backend
+ slow_tokenizer = self.slow_tokenizer_class(*args, **kwargs)
+ fast_tokenizer = convert_slow_tokenizer(slow_tokenizer)
+ elif not slow_tokenizer:
+ # We tried loading a slow_tokenizer with spm and failed, try to load with tiktoken
+ self.vocab_file = kwargs.get("vocab_file")
+ self.additional_special_tokens = kwargs.get("additional_special_tokens", [])
+ fast_tokenizer = convert_slow_tokenizer(self, from_tiktoken=True)
+ slow_tokenizer = None
+ else:
+ raise ValueError(
+ "Couldn't instantiate the backend tokenizer from one of: \n"
+ "(1) a `tokenizers` library serialization file, \n"
+ "(2) a slow tokenizer instance to convert or \n"
+ "(3) an equivalent slow tokenizer class to instantiate and convert. \n"
+ "You need to have sentencepiece or tiktoken installed to convert a slow tokenizer to a fast one."
+ )
+
+ self._tokenizer = fast_tokenizer
+
+ if slow_tokenizer is not None:
+ kwargs.update(slow_tokenizer.init_kwargs)
+
+ self._decode_use_source_tokenizer = False
+
+ _truncation = self._tokenizer.truncation
+
+ if _truncation is not None:
+ self._tokenizer.enable_truncation(**_truncation)
+ kwargs.setdefault("max_length", _truncation["max_length"])
+ kwargs.setdefault("truncation_side", _truncation["direction"])
+ kwargs.setdefault("stride", _truncation["stride"])
+ kwargs.setdefault("truncation_strategy", _truncation["strategy"])
+ else:
+ self._tokenizer.no_truncation()
+
+ _padding = self._tokenizer.padding
+ if _padding is not None:
+ self._tokenizer.enable_padding(**_padding)
+ kwargs.setdefault("pad_token", _padding["pad_token"])
+ kwargs.setdefault("pad_token_type_id", _padding["pad_type_id"])
+ kwargs.setdefault("padding_side", _padding["direction"])
+ kwargs.setdefault("max_length", _padding["length"])
+ kwargs.setdefault("pad_to_multiple_of", _padding["pad_to_multiple_of"])
+
+ # We call this after having initialized the backend tokenizer because we update it.
+ super().__init__(**kwargs)
+ self._tokenizer.encode_special_tokens = self.split_special_tokens
+
+ added_tokens_decoder_hash = {hash(repr(token)) for token in self.added_tokens_decoder}
+ tokens_to_add = [
+ token
+ for index, token in sorted(added_tokens_decoder.items(), key=lambda x: x[0])
+ if hash(repr(token)) not in added_tokens_decoder_hash
+ ]
+ encoder = list(self.added_tokens_encoder.keys()) + [str(token) for token in tokens_to_add]
+ # if some of the special tokens are strings, we check if we don't already have a token
+ tokens_to_add += [
+ token for token in self.all_special_tokens_extended if token not in encoder and token not in tokens_to_add
+ ]
+
+ if len(tokens_to_add) > 0:
+ tokens = []
+ special_tokens = self.all_special_tokens
+ for token in tokens_to_add:
+ is_special = (
+ (token.special or str(token) in special_tokens)
+ if isinstance(token, AddedToken)
+ else str(token) in special_tokens
+ )
+ if isinstance(token, str):
+ token = AddedToken(token, special=is_special)
+ else:
+ token.special = is_special
+ tokens.append(token)
+ if tokens:
+ self.add_tokens(tokens)
+
+ try:
+ pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
+ if pre_tok_state.get("add_prefix_space", self.add_prefix_space) != self.add_prefix_space:
+ pre_tok_class = getattr(pre_tokenizers_fast, pre_tok_state.pop("type"))
+ pre_tok_state["add_prefix_space"] = self.add_prefix_space
+ self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
+ except Exception:
+ # We'll get an error if there is no pre_tokenizer, or if it's a custom pre_tokenizer that can
+ # not be serialized. In those cases, we just ignore the error as there's no pre_tokenizer
+ # for which we need to update the `add_prefix_space` attribute.
+ pass
+
+ @property
+ def is_fast(self) -> bool:
+ return True
+
+ @property
+ def can_save_slow_tokenizer(self) -> bool:
+ """
+ `bool`: Whether or not the slow tokenizer can be saved. For a sentencepiece based slow tokenizer, this
+ can only be `True` if the original `"sentencepiece.model"` was not deleted.
+ """
+ if "vocab_file" in self.vocab_files_names and self.vocab_files_names["vocab_file"].endswith(".model"):
+ if hasattr(self, "vocab_file") and self.vocab_file:
+ # If the vocab file is a sentencepiece model, we can save it
+ return os.path.isfile(self.vocab_file)
+ return False
+ else:
+ return True
+
+ @property
+ def vocab_size(self) -> int:
+ """
+ `int`: Size of the base vocabulary (without the added tokens).
+ """
+ return self._tokenizer.get_vocab_size(with_added_tokens=False)
+
+ def get_vocab(self) -> dict[str, int]:
+ return self._tokenizer.get_vocab(with_added_tokens=True)
+
+ @property
+ def vocab(self) -> dict[str, int]:
+ return self.get_vocab()
+
+ @property
+ def added_tokens_encoder(self) -> dict[str, int]:
+ """
+ Returns the sorted mapping from string to index. The added tokens encoder is cached for performance
+ optimisation in `self._added_tokens_encoder` for the slow tokenizers.
+ """
+ return {k.content: v for v, k in sorted(self.added_tokens_decoder.items(), key=lambda item: item[0])}
+
+ @property
+ def added_tokens_decoder(self) -> dict[int, AddedToken]:
+ """
+ Returns the added tokens in the vocabulary as a dictionary of index to AddedToken.
+
+ Returns:
+ `dict[str, int]`: The added tokens.
+ """
+ return self._tokenizer.get_added_tokens_decoder()
+
+ def get_added_vocab(self) -> dict[str, int]:
+ """
+ Returns the added tokens in the vocabulary as a dictionary of token to index.
+
+ Returns:
+ `dict[str, int]`: The added tokens.
+ """
+ return {k.content: v for v, k in sorted(self.added_tokens_decoder.items(), key=lambda item: item[0])}
+
+ def __bool__(self) -> bool:
+ """
+ Returns True, to avoid expensive `assert tokenizer` gotchas.
+ """
+ return True
+
+ def __len__(self) -> int:
+ """
+ Size of the full vocabulary with the added tokens.
+ """
+ return self._tokenizer.get_vocab_size(with_added_tokens=True)
+
+ @property
+ def backend_tokenizer(self) -> TokenizerFast:
+ """
+ `tokenizers.implementations.BaseTokenizer`: The Rust tokenizer used as a backend.
+ """
+ return self._tokenizer
+
+ @property
+ def decoder(self) -> DecoderFast:
+ """
+ `tokenizers.decoders.Decoder`: The Rust decoder for this tokenizer.
+ """
+ return self._tokenizer.decoder
+
+ def _convert_encoding(
+ self,
+ encoding: EncodingFast,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ ) -> tuple[dict[str, Any], list[EncodingFast]]:
+ """
+ Convert the encoding representation (from low-level HuggingFace tokenizer output) to a python Dict and a list
+ of encodings, take care of building a batch from overflowing tokens.
+
+ Overflowing tokens are converted to additional examples (like batches) so the output values of the dict are
+ lists (overflows) of lists (tokens).
+
+ Output shape: (overflows, sequence length)
+ """
+ if return_token_type_ids is None:
+ return_token_type_ids = "token_type_ids" in self.model_input_names
+ if return_attention_mask is None:
+ return_attention_mask = "attention_mask" in self.model_input_names
+
+ if return_overflowing_tokens and encoding.overflowing is not None:
+ encodings = [encoding] + encoding.overflowing
+ else:
+ encodings = [encoding]
+
+ encoding_dict = defaultdict(list)
+ for e in encodings:
+ encoding_dict["input_ids"].append(e.ids)
+
+ if return_token_type_ids:
+ encoding_dict["token_type_ids"].append(e.type_ids)
+ if return_attention_mask:
+ encoding_dict["attention_mask"].append(e.attention_mask)
+ if return_special_tokens_mask:
+ encoding_dict["special_tokens_mask"].append(e.special_tokens_mask)
+ if return_offsets_mapping:
+ encoding_dict["offset_mapping"].append(e.offsets)
+ if return_length:
+ encoding_dict["length"].append(len(e.ids))
+
+ return encoding_dict, encodings
+
+ def convert_tokens_to_ids(self, tokens: Union[str, Iterable[str]]) -> Union[int, list[int]]:
+ """
+ Converts a token string (or a sequence of tokens) in a single integer id (or a Iterable of ids), using the
+ vocabulary.
+
+ Args:
+ tokens (`str` or `Iterable[str]`): One or several token(s) to convert to token id(s).
+
+ Returns:
+ `int` or `list[int]`: The token id or list of token ids.
+ """
+ if isinstance(tokens, str):
+ return self._convert_token_to_id_with_added_voc(tokens)
+
+ return [self._convert_token_to_id_with_added_voc(token) for token in tokens]
+
+ def _convert_token_to_id_with_added_voc(self, token: str) -> int:
+ index = self._tokenizer.token_to_id(token)
+ if index is None:
+ return self.unk_token_id
+ return index
+
+ def _convert_id_to_token(self, index: int) -> Optional[str]:
+ return self._tokenizer.id_to_token(int(index))
+
+ def _add_tokens(self, new_tokens: list[Union[str, AddedToken]], special_tokens=False) -> int:
+ if special_tokens:
+ return self._tokenizer.add_special_tokens(new_tokens)
+
+ return self._tokenizer.add_tokens(new_tokens)
+
+ def num_special_tokens_to_add(self, pair: bool = False) -> int:
+ """
+ Returns the number of added tokens when encoding a sequence with special tokens.
+
+
+
+ This encodes a dummy input and checks the number of added tokens, and is therefore not efficient. Do not put
+ this inside your training loop.
+
+
+
+ Args:
+ pair (`bool`, *optional*, defaults to `False`):
+ Whether the number of added tokens should be computed in the case of a sequence pair or a single
+ sequence.
+
+ Returns:
+ `int`: Number of special tokens added to sequences.
+ """
+ return self._tokenizer.num_special_tokens_to_add(pair)
+
+ def convert_ids_to_tokens(
+ self, ids: Union[int, list[int]], skip_special_tokens: bool = False
+ ) -> Union[str, list[str]]:
+ """
+ Converts a single index or a sequence of indices in a token or a sequence of tokens, using the vocabulary and
+ added tokens.
+
+ Args:
+ ids (`int` or `list[int]`):
+ The token id (or token ids) to convert to tokens.
+ skip_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not to remove special tokens in the decoding.
+
+ Returns:
+ `str` or `list[str]`: The decoded token(s).
+ """
+ if isinstance(ids, int):
+ return self._tokenizer.id_to_token(ids)
+ tokens = []
+ # self.all_special_ids is an @property which may be slow, so only compute it once before the loop
+ ids_to_skip = set(self.all_special_ids) if skip_special_tokens else set()
+ for index in ids:
+ index = int(index)
+ if index in ids_to_skip:
+ continue
+ tokens.append(self._tokenizer.id_to_token(index))
+ return tokens
+
+ def tokenize(self, text: str, pair: Optional[str] = None, add_special_tokens: bool = False, **kwargs) -> list[str]:
+ return self.encode_plus(text=text, text_pair=pair, add_special_tokens=add_special_tokens, **kwargs).tokens()
+
+ def set_truncation_and_padding(
+ self,
+ padding_strategy: PaddingStrategy,
+ truncation_strategy: TruncationStrategy,
+ max_length: int,
+ stride: int,
+ pad_to_multiple_of: Optional[int],
+ padding_side: Optional[str],
+ ):
+ """
+ Define the truncation and the padding strategies for fast tokenizers (provided by HuggingFace tokenizers
+ library) and restore the tokenizer settings afterwards.
+
+ The provided tokenizer has no padding / truncation strategy before the managed section. If your tokenizer set a
+ padding / truncation strategy before, then it will be reset to no padding / truncation when exiting the managed
+ section.
+
+ Args:
+ padding_strategy ([`~utils.PaddingStrategy`]):
+ The kind of padding that will be applied to the input
+ truncation_strategy ([`~tokenization_utils_base.TruncationStrategy`]):
+ The kind of truncation that will be applied to the input
+ max_length (`int`):
+ The maximum size of a sequence.
+ stride (`int`):
+ The stride to use when handling overflow.
+ pad_to_multiple_of (`int`, *optional*):
+ If set will pad the sequence to a multiple of the provided value. This is especially useful to enable
+ the use of Tensor Cores on NVIDIA hardware with compute capability `>= 7.5` (Volta).
+ padding_side (`str`, *optional*):
+ The side on which the model should have padding applied. Should be selected between ['right', 'left'].
+ Default value is picked from the class attribute of the same name.
+ """
+ _truncation = self._tokenizer.truncation
+ _padding = self._tokenizer.padding
+ # Set truncation and padding on the backend tokenizer
+ if truncation_strategy == TruncationStrategy.DO_NOT_TRUNCATE:
+ if _truncation is not None:
+ self._tokenizer.no_truncation()
+ else:
+ target = {
+ "max_length": max_length,
+ "stride": stride,
+ "strategy": truncation_strategy.value,
+ "direction": self.truncation_side,
+ }
+
+ # _truncation might contain more keys that the target `transformers`
+ # supports. Use only the target keys to trigger `enable_truncation`.
+ # This should enable this code to works on various `tokenizers`
+ # targets.
+ if _truncation is None:
+ current = None
+ else:
+ current = {k: _truncation.get(k, None) for k in target}
+
+ if current != target:
+ self._tokenizer.enable_truncation(**target)
+
+ if padding_strategy == PaddingStrategy.DO_NOT_PAD:
+ if _padding is not None:
+ self._tokenizer.no_padding()
+ else:
+ length = max_length if padding_strategy == PaddingStrategy.MAX_LENGTH else None
+ target = {
+ "length": length,
+ "direction": padding_side if padding_side is not None else self.padding_side,
+ "pad_id": self.pad_token_id,
+ "pad_token": self.pad_token,
+ "pad_type_id": self.pad_token_type_id,
+ "pad_to_multiple_of": pad_to_multiple_of,
+ }
+ if _padding != target:
+ self._tokenizer.enable_padding(**target)
+
+ def _batch_encode_plus(
+ self,
+ batch_text_or_text_pairs: Union[
+ list[TextInput], list[TextInputPair], list[PreTokenizedInput], list[PreTokenizedInputPair]
+ ],
+ add_special_tokens: bool = True,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ is_split_into_words: bool = False,
+ pad_to_multiple_of: Optional[int] = None,
+ padding_side: Optional[str] = None,
+ return_tensors: Optional[str] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ split_special_tokens: bool = False,
+ ) -> BatchEncoding:
+ if not isinstance(batch_text_or_text_pairs, (tuple, list)):
+ raise TypeError(
+ f"batch_text_or_text_pairs has to be a list or a tuple (got {type(batch_text_or_text_pairs)})"
+ )
+
+ # Set the truncation and padding strategy and restore the initial configuration
+ self.set_truncation_and_padding(
+ padding_strategy=padding_strategy,
+ truncation_strategy=truncation_strategy,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ padding_side=padding_side,
+ )
+
+ if self._tokenizer.encode_special_tokens != split_special_tokens:
+ self._tokenizer.encode_special_tokens = split_special_tokens
+
+ encodings = self._tokenizer.encode_batch(
+ batch_text_or_text_pairs,
+ add_special_tokens=add_special_tokens,
+ is_pretokenized=is_split_into_words,
+ )
+
+ # Convert encoding to dict
+ # `Tokens` has type: tuple[
+ # list[dict[str, list[list[int]]]] or list[dict[str, 2D-Tensor]],
+ # list[EncodingFast]
+ # ]
+ # with nested dimensions corresponding to batch, overflows, sequence length
+ tokens_and_encodings = [
+ self._convert_encoding(
+ encoding=encoding,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=return_offsets_mapping,
+ return_length=return_length,
+ verbose=verbose,
+ )
+ for encoding in encodings
+ ]
+
+ # Convert the output to have dict[list] from list[dict] and remove the additional overflows dimension
+ # From (variable) shape (batch, overflows, sequence length) to ~ (batch * overflows, sequence length)
+ # (we say ~ because the number of overflow varies with the example in the batch)
+ #
+ # To match each overflowing sample with the original sample in the batch
+ # we add an overflow_to_sample_mapping array (see below)
+ sanitized_tokens = {}
+ for key in tokens_and_encodings[0][0]:
+ stack = [e for item, _ in tokens_and_encodings for e in item[key]]
+ sanitized_tokens[key] = stack
+ sanitized_encodings = [e for _, item in tokens_and_encodings for e in item]
+
+ # If returning overflowing tokens, we need to return a mapping
+ # from the batch idx to the original sample
+ if return_overflowing_tokens:
+ overflow_to_sample_mapping = []
+ for i, (toks, _) in enumerate(tokens_and_encodings):
+ overflow_to_sample_mapping += [i] * len(toks["input_ids"])
+ sanitized_tokens["overflow_to_sample_mapping"] = overflow_to_sample_mapping
+
+ for input_ids in sanitized_tokens["input_ids"]:
+ self._eventual_warn_about_too_long_sequence(input_ids, max_length, verbose)
+ return BatchEncoding(sanitized_tokens, sanitized_encodings, tensor_type=return_tensors)
+
+ def _encode_plus(
+ self,
+ text: Union[TextInput, PreTokenizedInput],
+ text_pair: Optional[Union[TextInput, PreTokenizedInput]] = None,
+ add_special_tokens: bool = True,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ is_split_into_words: bool = False,
+ pad_to_multiple_of: Optional[int] = None,
+ padding_side: Optional[str] = None,
+ return_tensors: Optional[bool] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ split_special_tokens: bool = False,
+ **kwargs,
+ ) -> BatchEncoding:
+ batched_input = [(text, text_pair)] if text_pair else [text]
+ batched_output = self._batch_encode_plus(
+ batched_input,
+ is_split_into_words=is_split_into_words,
+ add_special_tokens=add_special_tokens,
+ padding_strategy=padding_strategy,
+ truncation_strategy=truncation_strategy,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ padding_side=padding_side,
+ return_tensors=return_tensors,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=return_offsets_mapping,
+ return_length=return_length,
+ verbose=verbose,
+ split_special_tokens=split_special_tokens,
+ **kwargs,
+ )
+
+ # Return tensor is None, then we can remove the leading batch axis
+ # Overflowing tokens are returned as a batch of output so we keep them in this case
+ if return_tensors is None and not return_overflowing_tokens:
+ batched_output = BatchEncoding(
+ {
+ key: (value[0] if len(value) > 0 and isinstance(value[0], list) else value)
+ for key, value in batched_output.items()
+ },
+ batched_output.encodings,
+ )
+
+ self._eventual_warn_about_too_long_sequence(batched_output["input_ids"], max_length, verbose)
+
+ return batched_output
+
+ def convert_tokens_to_string(self, tokens: list[str]) -> str:
+ return (
+ self.backend_tokenizer.decoder.decode(tokens)
+ if self.backend_tokenizer.decoder is not None
+ else " ".join(tokens)
+ )
+
+ def _decode(
+ self,
+ token_ids: Union[int, list[int]],
+ skip_special_tokens: bool = False,
+ clean_up_tokenization_spaces: Optional[bool] = None,
+ **kwargs,
+ ) -> str:
+ self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)
+
+ if isinstance(token_ids, int):
+ token_ids = [token_ids]
+ text = self._tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
+
+ clean_up_tokenization_spaces = (
+ clean_up_tokenization_spaces
+ if clean_up_tokenization_spaces is not None
+ else self.clean_up_tokenization_spaces
+ )
+ if clean_up_tokenization_spaces:
+ clean_text = self.clean_up_tokenization(text)
+ return clean_text
+ else:
+ return text
+
+ def _save_pretrained(
+ self,
+ save_directory: Union[str, os.PathLike],
+ file_names: tuple[str, ...],
+ legacy_format: Optional[bool] = None,
+ filename_prefix: Optional[str] = None,
+ ) -> tuple[str, ...]:
+ """
+ Save a tokenizer using the slow-tokenizer/legacy format: vocabulary + added tokens as well as in a unique JSON
+ file containing {config + vocab + added-tokens}.
+ """
+ save_directory = str(save_directory)
+
+ if self.slow_tokenizer_class is None and legacy_format is True:
+ raise ValueError(
+ "Your tokenizer does not have a legacy version defined and therefore cannot register this version. You"
+ " might consider leaving the legacy_format at `None` or setting it to `False`."
+ )
+
+ save_slow = (
+ (legacy_format is None or legacy_format is True)
+ and self.slow_tokenizer_class is not None
+ and self.can_save_slow_tokenizer
+ )
+ save_fast = legacy_format is None or legacy_format is False
+
+ if save_slow:
+ added_tokens_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + ADDED_TOKENS_FILE
+ )
+ # make sure to be forward compatible
+ added_vocab = {tok: index for tok, index in self.added_tokens_encoder.items() if index >= self.vocab_size}
+ if added_vocab:
+ with open(added_tokens_file, "w", encoding="utf-8") as f:
+ out_str = json.dumps(added_vocab, indent=2, sort_keys=True, ensure_ascii=False) + "\n"
+ f.write(out_str)
+
+ vocab_files = self.save_vocabulary(save_directory, filename_prefix=filename_prefix)
+ file_names = file_names + vocab_files + (added_tokens_file,)
+
+ if save_fast:
+ tokenizer_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + TOKENIZER_FILE
+ )
+ self.backend_tokenizer.save(tokenizer_file)
+ file_names = file_names + (tokenizer_file,)
+
+ return file_names
+
+ def train_new_from_iterator(
+ self,
+ text_iterator,
+ vocab_size,
+ length=None,
+ new_special_tokens=None,
+ special_tokens_map=None,
+ **kwargs,
+ ):
+ """
+ Trains a tokenizer on a new corpus with the same defaults (in terms of special tokens or tokenization pipeline)
+ as the current one.
+
+ Args:
+ text_iterator (generator of `list[str]`):
+ The training corpus. Should be a generator of batches of texts, for instance a list of lists of texts
+ if you have everything in memory.
+ vocab_size (`int`):
+ The size of the vocabulary you want for your tokenizer.
+ length (`int`, *optional*):
+ The total number of sequences in the iterator. This is used to provide meaningful progress tracking
+ new_special_tokens (list of `str` or `AddedToken`, *optional*):
+ A list of new special tokens to add to the tokenizer you are training.
+ special_tokens_map (`dict[str, str]`, *optional*):
+ If you want to rename some of the special tokens this tokenizer uses, pass along a mapping old special
+ token name to new special token name in this argument.
+ kwargs (`dict[str, Any]`, *optional*):
+ Additional keyword arguments passed along to the trainer from the 🤗 Tokenizers library.
+
+ Returns:
+ [`PreTrainedTokenizerFast`]: A new tokenizer of the same type as the original one, trained on
+ `text_iterator`.
+
+ """
+ tokenizer_json = json.loads(self._tokenizer.to_str())
+ # Remove added tokens for now (uses IDs of tokens)
+ added_tokens = tokenizer_json.pop("added_tokens")
+ # Remove post processor for now (uses IDs of tokens)
+ post_processor = tokenizer_json.pop("post_processor")
+
+ unk_token = None
+ # Remove vocab
+ if tokenizer_json["model"]["type"] == "BPE":
+ tokenizer_json["model"]["vocab"] = {}
+ tokenizer_json["model"]["merges"] = []
+ elif tokenizer_json["model"]["type"] == "Unigram":
+ if tokenizer_json["model"]["unk_id"] is not None:
+ unk_id = tokenizer_json["model"]["unk_id"]
+ unk_token = tokenizer_json["model"]["vocab"][unk_id][0]
+ if special_tokens_map is not None and unk_token in special_tokens_map:
+ unk_token = special_tokens_map[unk_token]
+ tokenizer_json["model"]["unk_id"] = 0
+ tokenizer_json["model"]["vocab"] = [[unk_token, 0.0]]
+ elif tokenizer_json["model"]["type"] in ["WordLevel", "WordPiece"]:
+ tokenizer_json["model"]["vocab"] = {}
+ else:
+ raise ValueError(
+ f"This method does not support this type of tokenizer (found {tokenizer_json['model']['type']}) "
+ "only BPE, Unigram, WordLevel and WordPiece."
+ )
+
+ if (
+ special_tokens_map is not None
+ and "unk_token" in tokenizer_json["model"]
+ and tokenizer_json["model"]["unk_token"] in special_tokens_map
+ ):
+ tokenizer_json["model"]["unk_token"] = special_tokens_map[tokenizer_json["model"]["unk_token"]]
+
+ tokenizer = TokenizerFast.from_str(json.dumps(tokenizer_json))
+
+ # Get the special tokens from the current tokenizer if none are specified.
+ special_tokens = []
+ for added_token in added_tokens:
+ special = added_token.pop("special", None)
+ _ = added_token.pop("id", None)
+ if tokenizer_json["model"]["type"] != "Unigram" and not special:
+ continue
+ if special_tokens_map is not None and added_token["content"] in special_tokens_map:
+ added_token["content"] = special_tokens_map[added_token["content"]]
+ special_tokens.append(AddedToken(**added_token))
+
+ if new_special_tokens is not None:
+ special_tokens.extend(new_special_tokens)
+
+ # Trainer needs to know the end of word / continuing subword thingies in BPE
+ if (
+ tokenizer_json["model"]["type"] == "BPE"
+ and "continuing_subword_prefix" not in kwargs
+ and tokenizer_json["model"]["continuing_subword_prefix"] is not None
+ ):
+ kwargs["continuing_subword_prefix"] = tokenizer_json["model"]["continuing_subword_prefix"]
+ if (
+ tokenizer_json["model"]["type"] == "BPE"
+ and "end_of_word_suffix" not in kwargs
+ and tokenizer_json["model"]["end_of_word_suffix"] is not None
+ ):
+ kwargs["end_of_word_suffix"] = tokenizer_json["model"]["end_of_word_suffix"]
+ if tokenizer_json["model"]["type"] == "Unigram" and unk_token is not None:
+ kwargs["unk_token"] = unk_token
+ if tokenizer_json["pre_tokenizer"] is not None:
+ if (
+ tokenizer_json["pre_tokenizer"]["type"] == "ByteLevel"
+ or tokenizer_json["pre_tokenizer"]["type"] == "Sequence"
+ and "pretokenizers" in tokenizer_json["pre_tokenizer"]
+ and any(
+ pretokenizer["type"] == "ByteLevel"
+ for pretokenizer in tokenizer_json["pre_tokenizer"]["pretokenizers"]
+ )
+ ):
+ kwargs["initial_alphabet"] = pre_tokenizers_fast.ByteLevel.alphabet()
+
+ trainer_class = MODEL_TO_TRAINER_MAPPING[tokenizer_json["model"]["type"]]
+ trainer = trainer_class(vocab_size=vocab_size, special_tokens=special_tokens, **kwargs)
+ tokenizer.train_from_iterator(text_iterator, length=length, trainer=trainer)
+
+ if post_processor is not None:
+ trained_tokenizer_json = json.loads(tokenizer.to_str())
+ # Almost done, we just have to adjust the token IDs in the post processor
+ if "special_tokens" in post_processor:
+ for key in post_processor["special_tokens"]:
+ tokens = post_processor["special_tokens"][key]["tokens"]
+ if special_tokens_map is not None:
+ tokens = [special_tokens_map.get(token, token) for token in tokens]
+ post_processor["special_tokens"][key]["tokens"] = tokens
+ for token in tokens:
+ token_id = tokenizer.token_to_id(token)
+ if token_id is None:
+ raise ValueError(
+ "Attempted to set a token in the post processor that does not exist in the mapping"
+ )
+
+ post_processor["special_tokens"][key]["ids"] = [tokenizer.token_to_id(token) for token in tokens]
+
+ for special_token in ["cls", "sep"]:
+ if special_token in post_processor:
+ token, _ = post_processor[special_token]
+ if special_tokens_map is not None and token in special_tokens_map:
+ token = special_tokens_map[token]
+ token_id = tokenizer.token_to_id(token)
+ if token_id is None:
+ raise ValueError(
+ "Attempted to set a token in the post processor that does not exist in the mapping"
+ )
+ post_processor[special_token] = [token, token_id]
+
+ trained_tokenizer_json["post_processor"] = post_processor
+ tokenizer = TokenizerFast.from_str(json.dumps(trained_tokenizer_json))
+
+ kwargs = self.init_kwargs.copy()
+ # Map pad/cls/mask token at the Transformers level
+ special_tokens_list = SpecialTokensMixin.SPECIAL_TOKENS_ATTRIBUTES.copy()
+ special_tokens_list.remove("additional_special_tokens")
+ for token in special_tokens_list:
+ if getattr(self, token) is not None:
+ special_token = getattr(self, token)
+ if special_tokens_map is not None and special_token in special_tokens_map:
+ special_token = special_tokens_map[special_token]
+
+ special_token_full = self._special_tokens_map.get(token, None)
+ if isinstance(special_token_full, AddedToken):
+ # Create an added token with the same parameters except the content
+ kwargs[token] = AddedToken(
+ special_token,
+ single_word=special_token_full.single_word,
+ lstrip=special_token_full.lstrip,
+ rstrip=special_token_full.rstrip,
+ normalized=special_token_full.normalized,
+ special=True,
+ )
+ else:
+ kwargs[token] = special_token
+
+ additional_special_tokens = self.additional_special_tokens
+ if new_special_tokens is not None:
+ additional_special_tokens.extend(new_special_tokens)
+ if len(additional_special_tokens) > 0:
+ kwargs["additional_special_tokens"] = additional_special_tokens
+
+ return self.__class__(tokenizer_object=tokenizer, **kwargs)
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/trainer.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bfb72b4eaf52dbf37b3bfdcab10531d484ddf13
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/trainer.py
@@ -0,0 +1,5723 @@
+# Copyright 2020-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+The Trainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task.
+"""
+
+import contextlib
+import copy
+import functools
+import glob
+import importlib.metadata
+import inspect
+import json
+import math
+import os
+import random
+import re
+import shutil
+import sys
+import tempfile
+import time
+import warnings
+from collections.abc import Iterator, Mapping
+from functools import partial
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+
+
+# Integrations must be imported before ML frameworks:
+# ruff: isort: off
+from .integrations import (
+ get_reporting_integration_callbacks,
+)
+
+# ruff: isort: on
+
+import huggingface_hub.utils as hf_hub_utils
+import numpy as np
+import safetensors.torch
+import torch
+import torch.distributed as dist
+from huggingface_hub import ModelCard, create_repo, upload_folder
+from packaging import version
+from torch import nn
+from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler
+
+from . import __version__
+from .configuration_utils import PretrainedConfig
+from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
+from .debug_utils import DebugOption, DebugUnderflowOverflow
+from .feature_extraction_sequence_utils import SequenceFeatureExtractor
+from .feature_extraction_utils import FeatureExtractionMixin
+from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend
+from .image_processing_utils import BaseImageProcessor
+from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
+from .integrations.tpu import tpu_spmd_dataloader
+from .modelcard import TrainingSummary
+from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
+from .models.auto.modeling_auto import (
+ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
+ MODEL_MAPPING_NAMES,
+)
+from .optimization import Adafactor, get_scheduler
+from .processing_utils import ProcessorMixin
+from .pytorch_utils import (
+ is_torch_greater_or_equal_than_2_3,
+)
+from .tokenization_utils_base import PreTrainedTokenizerBase
+from .trainer_callback import (
+ CallbackHandler,
+ DefaultFlowCallback,
+ ExportableState,
+ PrinterCallback,
+ ProgressCallback,
+ TrainerCallback,
+ TrainerControl,
+ TrainerState,
+)
+from .trainer_pt_utils import (
+ DistributedTensorGatherer,
+ EvalLoopContainer,
+ IterableDatasetShard,
+ LabelSmoother,
+ LayerWiseDummyOptimizer,
+ LengthGroupedSampler,
+ SequentialDistributedSampler,
+ distributed_broadcast_scalars,
+ distributed_concat,
+ find_batch_size,
+ get_model_param_count,
+ get_module_class_from_name,
+ get_parameter_names,
+ nested_concat,
+ nested_detach,
+ nested_numpify,
+ nested_xla_mesh_reduce,
+ reissue_pt_warnings,
+ remove_dummy_checkpoint,
+ set_rng_state_for_device,
+)
+from .trainer_utils import (
+ PREFIX_CHECKPOINT_DIR,
+ BestRun,
+ EvalLoopOutput,
+ EvalPrediction,
+ HPSearchBackend,
+ HubStrategy,
+ PredictionOutput,
+ RemoveColumnsCollator,
+ SaveStrategy,
+ TrainerMemoryTracker,
+ TrainOutput,
+ check_target_module_exists,
+ default_compute_objective,
+ denumpify_detensorize,
+ enable_full_determinism,
+ find_executable_batch_size,
+ get_last_checkpoint,
+ has_length,
+ neftune_post_forward_hook,
+ number_of_arguments,
+ seed_worker,
+ set_seed,
+ speed_metrics,
+)
+from .training_args import OptimizerNames, ParallelMode, TrainingArguments
+from .utils import (
+ ADAPTER_CONFIG_NAME,
+ ADAPTER_SAFE_WEIGHTS_NAME,
+ ADAPTER_WEIGHTS_NAME,
+ CONFIG_NAME,
+ GENERATION_CONFIG_NAME,
+ SAFE_WEIGHTS_INDEX_NAME,
+ SAFE_WEIGHTS_NAME,
+ WEIGHTS_INDEX_NAME,
+ WEIGHTS_NAME,
+ XLA_FSDPV2_MIN_VERSION,
+ PushInProgress,
+ PushToHubMixin,
+ can_return_loss,
+ check_torch_load_is_safe,
+ find_labels,
+ is_accelerate_available,
+ is_apollo_torch_available,
+ is_bitsandbytes_available,
+ is_datasets_available,
+ is_galore_torch_available,
+ is_grokadamw_available,
+ is_in_notebook,
+ is_liger_kernel_available,
+ is_lomo_available,
+ is_peft_available,
+ is_sagemaker_dp_enabled,
+ is_sagemaker_mp_enabled,
+ is_schedulefree_available,
+ is_torch_hpu_available,
+ is_torch_mlu_available,
+ is_torch_mps_available,
+ is_torch_musa_available,
+ is_torch_neuroncore_available,
+ is_torch_npu_available,
+ is_torch_optimi_available,
+ is_torch_xla_available,
+ is_torch_xpu_available,
+ is_torchao_available,
+ logging,
+ strtobool,
+)
+from .utils.deprecation import deprecate_kwarg
+from .utils.import_utils import requires
+from .utils.quantization_config import QuantizationMethod
+
+
+DEFAULT_CALLBACKS = [DefaultFlowCallback]
+DEFAULT_PROGRESS_CALLBACK = ProgressCallback
+
+if is_in_notebook():
+ from .utils.notebook import NotebookProgressCallback
+
+ DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback
+
+if is_datasets_available():
+ import datasets
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+ import torch_xla.debug.metrics as met
+ import torch_xla.runtime as xr
+ from torch_xla import __version__ as XLA_VERSION
+
+ IS_XLA_FSDPV2_POST_2_2 = version.parse(XLA_VERSION) >= version.parse(XLA_FSDPV2_MIN_VERSION)
+ if IS_XLA_FSDPV2_POST_2_2:
+ import torch_xla.distributed.spmd as xs
+else:
+ IS_XLA_FSDPV2_POST_2_2 = False
+
+
+if is_sagemaker_mp_enabled():
+ import smdistributed.modelparallel.torch as smp
+ from smdistributed.modelparallel import __version__ as SMP_VERSION
+
+ IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")
+
+ from .trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat
+else:
+ IS_SAGEMAKER_MP_POST_1_10 = False
+
+if is_peft_available():
+ from peft import PeftModel
+
+if is_accelerate_available():
+ from accelerate import Accelerator, skip_first_batches
+ from accelerate import __version__ as accelerate_version
+ from accelerate.state import AcceleratorState
+ from accelerate.utils import (
+ AutocastKwargs,
+ DistributedDataParallelKwargs,
+ DistributedType,
+ load_fsdp_model,
+ load_fsdp_optimizer,
+ save_fsdp_model,
+ save_fsdp_optimizer,
+ )
+
+ DATA_SAMPLERS = [RandomSampler]
+ if version.parse(accelerate_version) > version.parse("1.3.0"):
+ from accelerate.utils import TorchTensorParallelPlugin
+ from accelerate.data_loader import SeedableRandomSampler
+
+ DATA_SAMPLERS += [SeedableRandomSampler]
+
+ if is_deepspeed_available():
+ from accelerate.utils import DeepSpeedSchedulerWrapper
+
+if is_accelerate_available("0.28.0"):
+ from accelerate.utils import DataLoaderConfiguration
+
+
+def _is_peft_model(model):
+ if is_peft_available():
+ classes_to_check = (PeftModel,)
+ # Here we also check if the model is an instance of `PeftMixedModel` introduced in peft>=0.7.0: https://github.com/huggingface/transformers/pull/28321
+ if version.parse(importlib.metadata.version("peft")) >= version.parse("0.7.0"):
+ from peft import PeftMixedModel
+
+ classes_to_check = (*classes_to_check, PeftMixedModel)
+ return isinstance(model, classes_to_check)
+ return False
+
+
+def _get_fsdp_ckpt_kwargs():
+ # TODO: @AjayP13, @younesbelkada replace this check with version check at the next `accelerate` release
+ if is_accelerate_available() and "adapter_only" in list(inspect.signature(save_fsdp_model).parameters):
+ return {"adapter_only": True}
+ else:
+ return {}
+
+
+def safe_globals():
+ # Starting from version 2.4 PyTorch introduces a check for the objects loaded
+ # with torch.load(weights_only=True). Starting from 2.6 weights_only=True becomes
+ # a default and requires allowlisting of objects being loaded.
+ # See: https://github.com/pytorch/pytorch/pull/137602
+ # See: https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals
+ # See: https://github.com/huggingface/accelerate/pull/3036
+ if version.parse(torch.__version__).release < version.parse("2.6").release:
+ return contextlib.nullcontext()
+
+ np_core = np._core if version.parse(np.__version__) >= version.parse("2.0.0") else np.core
+ allowlist = [np_core.multiarray._reconstruct, np.ndarray, np.dtype]
+ # numpy >1.25 defines numpy.dtypes.UInt32DType, but below works for
+ # all versions of numpy
+ allowlist += [type(np.dtype(np.uint32))]
+
+ return torch.serialization.safe_globals(allowlist)
+
+
+if TYPE_CHECKING:
+ import optuna
+
+logger = logging.get_logger(__name__)
+
+
+# Name of the files used for checkpointing
+TRAINING_ARGS_NAME = "training_args.bin"
+TRAINER_STATE_NAME = "trainer_state.json"
+OPTIMIZER_NAME = "optimizer.pt"
+SCALER_NAME = "scaler.pt"
+OPTIMIZER_NAME_BIN = "optimizer.bin"
+SCHEDULER_NAME = "scheduler.pt"
+FSDP_MODEL_NAME = "pytorch_model_fsdp"
+
+
+@requires(
+ backends=(
+ "torch",
+ "accelerate",
+ )
+)
+class Trainer:
+ """
+ Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers.
+
+ Args:
+ model ([`PreTrainedModel`] or `torch.nn.Module`, *optional*):
+ The model to train, evaluate or use for predictions. If not provided, a `model_init` must be passed.
+
+
+
+ [`Trainer`] is optimized to work with the [`PreTrainedModel`] provided by the library. You can still use
+ your own models defined as `torch.nn.Module` as long as they work the same way as the 🤗 Transformers
+ models.
+
+
+
+ args ([`TrainingArguments`], *optional*):
+ The arguments to tweak for training. Will default to a basic instance of [`TrainingArguments`] with the
+ `output_dir` set to a directory named *tmp_trainer* in the current directory if not provided.
+ data_collator (`DataCollator`, *optional*):
+ The function to use to form a batch from a list of elements of `train_dataset` or `eval_dataset`. Will
+ default to [`default_data_collator`] if no `processing_class` is provided, an instance of
+ [`DataCollatorWithPadding`] otherwise if the processing_class is a feature extractor or tokenizer.
+ train_dataset (Union[`torch.utils.data.Dataset`, `torch.utils.data.IterableDataset`, `datasets.Dataset`], *optional*):
+ The dataset to use for training. If it is a [`~datasets.Dataset`], columns not accepted by the
+ `model.forward()` method are automatically removed.
+
+ Note that if it's a `torch.utils.data.IterableDataset` with some randomization and you are training in a
+ distributed fashion, your iterable dataset should either use a internal attribute `generator` that is a
+ `torch.Generator` for the randomization that must be identical on all processes (and the Trainer will
+ manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally
+ sets the seed of the RNGs used.
+ eval_dataset (Union[`torch.utils.data.Dataset`, dict[str, `torch.utils.data.Dataset`, `datasets.Dataset`]), *optional*):
+ The dataset to use for evaluation. If it is a [`~datasets.Dataset`], columns not accepted by the
+ `model.forward()` method are automatically removed. If it is a dictionary, it will evaluate on each
+ dataset prepending the dictionary key to the metric name.
+ processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
+ Processing class used to process the data. If provided, will be used to automatically process the inputs
+ for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
+ reuse the fine-tuned model.
+ This supersedes the `tokenizer` argument, which is now deprecated.
+ model_init (`Callable[[], PreTrainedModel]`, *optional*):
+ A function that instantiates the model to be used. If provided, each call to [`~Trainer.train`] will start
+ from a new instance of the model as given by this function.
+
+ The function may have zero argument, or a single one containing the optuna/Ray Tune/SigOpt trial object, to
+ be able to choose different architectures according to hyper parameters (such as layer count, sizes of
+ inner layers, dropout probabilities etc).
+ compute_loss_func (`Callable`, *optional*):
+ A function that accepts the raw model outputs, labels, and the number of items in the entire accumulated
+ batch (batch_size * gradient_accumulation_steps) and returns the loss. For example, see the default [loss function](https://github.com/huggingface/transformers/blob/052e652d6d53c2b26ffde87e039b723949a53493/src/transformers/trainer.py#L3618) used by [`Trainer`].
+ compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
+ The function that will be used to compute metrics at evaluation. Must take a [`EvalPrediction`] and return
+ a dictionary string to metric values. *Note* When passing TrainingArgs with `batch_eval_metrics` set to
+ `True`, your compute_metrics function must take a boolean `compute_result` argument. This will be triggered
+ after the last eval batch to signal that the function needs to calculate and return the global summary
+ statistics rather than accumulating the batch-level statistics
+ callbacks (List of [`TrainerCallback`], *optional*):
+ A list of callbacks to customize the training loop. Will add those to the list of default callbacks
+ detailed in [here](callback).
+
+ If you want to remove one of the default callbacks used, use the [`Trainer.remove_callback`] method.
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
+ A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
+ model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
+ optimizer_cls_and_kwargs (`tuple[Type[torch.optim.Optimizer], dict[str, Any]]`, *optional*):
+ A tuple containing the optimizer class and keyword arguments to use.
+ Overrides `optim` and `optim_args` in `args`. Incompatible with the `optimizers` argument.
+
+ Unlike `optimizers`, this argument avoids the need to place model parameters on the correct devices before initializing the Trainer.
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*):
+ A function that preprocess the logits right before caching them at each evaluation step. Must take two
+ tensors, the logits and the labels, and return the logits once processed as desired. The modifications made
+ by this function will be reflected in the predictions received by `compute_metrics`.
+
+ Note that the labels (second parameter) will be `None` if the dataset does not have them.
+
+ Important attributes:
+
+ - **model** -- Always points to the core model. If using a transformers model, it will be a [`PreTrainedModel`]
+ subclass.
+ - **model_wrapped** -- Always points to the most external model in case one or more other modules wrap the
+ original model. This is the model that should be used for the forward pass. For example, under `DeepSpeed`,
+ the inner model is wrapped in `DeepSpeed` and then again in `torch.nn.DistributedDataParallel`. If the inner
+ model hasn't been wrapped, then `self.model_wrapped` is the same as `self.model`.
+ - **is_model_parallel** -- Whether or not a model has been switched to a model parallel mode (different from
+ data parallelism, this means some of the model layers are split on different GPUs).
+ - **place_model_on_device** -- Whether or not to automatically place the model on the device - it will be set
+ to `False` if model parallel or deepspeed is used, or if the default
+ `TrainingArguments.place_model_on_device` is overridden to return `False` .
+ - **is_in_train** -- Whether or not a model is currently running `train` (e.g. when `evaluate` is called while
+ in `train`)
+
+ """
+
+ # Those are used as methods of the Trainer in examples.
+ from .trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics, save_state
+
+ @deprecate_kwarg("tokenizer", new_name="processing_class", version="5.0.0", raise_if_both_names=True)
+ def __init__(
+ self,
+ model: Union[PreTrainedModel, nn.Module, None] = None,
+ args: Optional[TrainingArguments] = None,
+ data_collator: Optional[DataCollator] = None,
+ train_dataset: Optional[Union[Dataset, IterableDataset, "datasets.Dataset"]] = None,
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset], "datasets.Dataset"]] = None,
+ processing_class: Optional[
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
+ ] = None,
+ model_init: Optional[Callable[..., PreTrainedModel]] = None,
+ compute_loss_func: Optional[Callable] = None,
+ compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
+ callbacks: Optional[list[TrainerCallback]] = None,
+ optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
+ optimizer_cls_and_kwargs: Optional[tuple[type[torch.optim.Optimizer], dict[str, Any]]] = None,
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
+ ):
+ if args is None:
+ output_dir = "tmp_trainer"
+ logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.")
+ args = TrainingArguments(output_dir=output_dir)
+ if args.batch_eval_metrics and compute_metrics is not None:
+ if "compute_result" not in inspect.signature(compute_metrics).parameters:
+ raise ValueError(
+ "When using `batch_eval_metrics`, your `compute_metrics` function must take a `compute_result`"
+ " boolean argument which will be triggered after the last batch of the eval set to signal that the"
+ " summary statistics should be returned by the function."
+ )
+ if args.eval_strategy is not None and args.eval_strategy != "no" and eval_dataset is None:
+ raise ValueError(
+ f"You have set `args.eval_strategy` to {args.eval_strategy} but you didn't pass an `eval_dataset` to `Trainer`. Either set `args.eval_strategy` to `no` or pass an `eval_dataset`. "
+ )
+ if args.save_strategy == SaveStrategy.BEST or args.load_best_model_at_end:
+ if args.metric_for_best_model is None:
+ raise ValueError(
+ "`args.metric_for_best_model` must be provided when using 'best' save_strategy or if `args.load_best_model_at_end` is set to `True`."
+ )
+
+ self.args = args
+ self.compute_loss_func = compute_loss_func
+ # Seed must be set before instantiating the model when using model
+ enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
+
+ self.hp_name = None
+ self.deepspeed = None
+ self.is_in_train = False
+ self.model = model
+ self.create_accelerator_and_postprocess()
+
+ # memory metrics - must set up as early as possible
+ self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
+ self._memory_tracker.start()
+
+ # set the correct log level depending on the node
+ log_level = args.get_process_log_level()
+ logging.set_verbosity(log_level)
+
+ # force device and distributed setup init explicitly
+ args._setup_devices
+
+ if model is None:
+ if model_init is not None:
+ self.model_init = model_init
+ model = self.call_model_init()
+ else:
+ raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument")
+ else:
+ if model_init is not None:
+ warnings.warn(
+ "`Trainer` requires either a `model` or `model_init` argument, but not both. `model_init` will"
+ " overwrite your model when calling the `train` method. This will become a fatal error in the next"
+ " release.",
+ FutureWarning,
+ )
+ self.model_init = model_init
+
+ if model.__class__.__name__ in MODEL_MAPPING_NAMES:
+ raise ValueError(
+ f"The model you have picked ({model.__class__.__name__}) cannot be used as is for training: it only "
+ "computes hidden states and does not accept any labels. You should choose a model with a head "
+ "suitable for your task like any of the `AutoModelForXxx` listed at "
+ "https://huggingface.co/docs/transformers/model_doc/auto"
+ )
+
+ if getattr(model, "is_parallelizable", False) and getattr(model, "model_parallel", False):
+ self.is_model_parallel = True
+ else:
+ self.is_model_parallel = False
+
+ if getattr(model, "hf_device_map", None) is not None:
+ devices = [device for device in set(model.hf_device_map.values()) if device not in ["cpu", "disk"]]
+ if len(devices) > 1:
+ self.is_model_parallel = True
+ elif len(devices) == 1:
+ self.is_model_parallel = self.args.device != torch.device(devices[0])
+ else:
+ self.is_model_parallel = False
+
+ # warn users
+ if self.is_model_parallel:
+ logger.info(
+ "You have loaded a model on multiple GPUs. `is_model_parallel` attribute will be force-set"
+ " to `True` to avoid any unexpected behavior such as device placement mismatching."
+ )
+
+ if self.args.use_liger_kernel:
+ if is_liger_kernel_available():
+ from liger_kernel.transformers import _apply_liger_kernel_to_instance
+
+ # Prepare kernel config - use provided config or default (empty dict for default behavior)
+ kernel_config = self.args.liger_kernel_config if self.args.liger_kernel_config is not None else {}
+
+ if isinstance(model, PreTrainedModel):
+ # Patch the model with liger kernels. Use the specified or default kernel configurations.
+ _apply_liger_kernel_to_instance(model=model, **kernel_config)
+ elif hasattr(model, "get_base_model") and isinstance(model.get_base_model(), PreTrainedModel):
+ # Patch the base model with liger kernels where model is a PeftModel. Use the specified or default kernel configurations.
+ _apply_liger_kernel_to_instance(model=model.get_base_model(), **kernel_config)
+ else:
+ logger.warning(
+ "The model is not an instance of PreTrainedModel. No liger kernels will be applied."
+ )
+ else:
+ raise ImportError(
+ "You have set `use_liger_kernel` to `True` but liger-kernel >= 0.3.0 is not available. "
+ "Please install it with `pip install liger-kernel`"
+ )
+
+ _is_quantized_and_base_model = getattr(model, "is_quantized", False) and not getattr(
+ model, "_hf_peft_config_loaded", False
+ )
+ _quantization_method_supports_training = (
+ getattr(model, "hf_quantizer", None) is not None and model.hf_quantizer.is_trainable
+ )
+
+ _is_model_quantized_and_qat_trainable = getattr(model, "hf_quantizer", None) is not None and getattr(
+ model.hf_quantizer, "is_qat_trainable", False
+ )
+
+ # Filter out quantized + compiled models
+ if _is_quantized_and_base_model and hasattr(model, "_orig_mod"):
+ raise ValueError(
+ "You cannot fine-tune quantized model with `torch.compile()` make sure to pass a non-compiled model when fine-tuning a quantized model with PEFT"
+ )
+
+ # At this stage the model is already loaded
+ if _is_quantized_and_base_model and not _is_peft_model(model) and not _is_model_quantized_and_qat_trainable:
+ raise ValueError(
+ "You cannot perform fine-tuning on purely quantized models. Please attach trainable adapters on top of"
+ " the quantized model to correctly perform fine-tuning. Please see: https://huggingface.co/docs/transformers/peft"
+ " for more details"
+ )
+ elif _is_quantized_and_base_model and not _quantization_method_supports_training:
+ raise ValueError(
+ f"The model you are trying to fine-tune is quantized with {model.hf_quantizer.quantization_config.quant_method}"
+ " but that quantization method do not support training. Please open an issue on GitHub: https://github.com/huggingface/transformers"
+ f" to request the support for training support for {model.hf_quantizer.quantization_config.quant_method}"
+ )
+
+ self.is_fsdp_xla_enabled = args.fsdp_config["xla"]
+ if len(args.fsdp) > 0:
+ if self.is_deepspeed_enabled:
+ raise ValueError(
+ "Using --fsdp xxx together with --deepspeed is not possible, deactivate one of those flags."
+ )
+ if not args.fsdp_config["xla"] and args.parallel_mode != ParallelMode.DISTRIBUTED:
+ raise ValueError("Using fsdp only works in distributed training.")
+
+ # one place to sort out whether to place the model on device or not
+ # postpone switching model to cuda when:
+ # 1. MP - since we are trying to fit a much bigger than 1 gpu model
+ # 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway,
+ # and we only use deepspeed for training at the moment
+ # 3. full bf16 or fp16 eval - since the model needs to be cast to the right dtype first
+ # 4. FSDP - same as MP
+ self.place_model_on_device = args.place_model_on_device
+ if (
+ self.is_model_parallel
+ or self.is_deepspeed_enabled
+ or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train)
+ or self.is_fsdp_xla_enabled
+ or self.is_fsdp_enabled
+ ):
+ self.place_model_on_device = False
+
+ default_collator = (
+ DataCollatorWithPadding(processing_class)
+ if processing_class is not None
+ and isinstance(processing_class, (PreTrainedTokenizerBase, SequenceFeatureExtractor))
+ else default_data_collator
+ )
+ self.data_collator = data_collator if data_collator is not None else default_collator
+ self.train_dataset = train_dataset
+ self.eval_dataset = eval_dataset
+ self.processing_class = processing_class
+
+ # Bnb Quantized models doesn't support `.to` operation.
+ if (
+ self.place_model_on_device
+ and getattr(model, "quantization_method", None) != QuantizationMethod.BITS_AND_BYTES
+ ):
+ self._move_model_to_device(model, args.device)
+
+ # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs
+ if self.is_model_parallel:
+ self.args._n_gpu = 1
+
+ # later use `self.model is self.model_wrapped` to check if it's wrapped or not
+ self.model_wrapped = model
+ self.model = model
+
+ # Just in case the model was wrapped outside of the `Trainer`
+ unwrapped_model = self.accelerator.unwrap_model(model)
+ # We also unwrap peft model
+ if _is_peft_model(unwrapped_model):
+ if hasattr(unwrapped_model, "get_base_model"):
+ unwrapped_model = unwrapped_model.get_base_model()
+ elif hasattr(unwrapped_model, "base_model") and hasattr(unwrapped_model.base_model, "model"):
+ unwrapped_model = unwrapped_model.base_model.model
+ else:
+ raise AttributeError("Cannot extract base model safely from this PEFT wrapper.")
+
+ # Check if the model has explicit setup for loss kwargs,
+ # if not, check if `**kwargs` are in model.forward
+ if hasattr(unwrapped_model, "accepts_loss_kwargs"):
+ self.model_accepts_loss_kwargs = unwrapped_model.accepts_loss_kwargs
+ else:
+ forward_params = inspect.signature(unwrapped_model.forward).parameters
+ self.model_accepts_loss_kwargs = any(
+ k.kind == inspect.Parameter.VAR_KEYWORD for k in forward_params.values()
+ )
+
+ self.neftune_noise_alpha = args.neftune_noise_alpha
+
+ self.compute_metrics = compute_metrics
+ self.preprocess_logits_for_metrics = preprocess_logits_for_metrics
+ self.optimizer, self.lr_scheduler = optimizers
+ self.optimizer_cls_and_kwargs = optimizer_cls_and_kwargs
+ if self.optimizer_cls_and_kwargs is not None and self.optimizer is not None:
+ raise RuntimeError("Passing both `optimizers` and `optimizer_cls_and_kwargs` arguments is incompatible.")
+ if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
+ raise RuntimeError(
+ "Passing a `model_init` is incompatible with providing the `optimizers` argument. "
+ "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
+ )
+ if is_torch_xla_available() and self.optimizer is not None:
+ for param in self.model.parameters():
+ model_device = param.device
+ break
+ for param_group in self.optimizer.param_groups:
+ if len(param_group["params"]) > 0:
+ optimizer_device = param_group["params"][0].device
+ break
+ if model_device != optimizer_device:
+ raise ValueError(
+ "The model and the optimizer parameters are not on the same device, which probably means you"
+ " created an optimizer around your model **before** putting on the device and passing it to the"
+ " `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and"
+ " `model.to(xm.xla_device())` is performed before the optimizer creation in your script."
+ )
+ if (self.is_fsdp_xla_enabled or self.is_fsdp_enabled) and (
+ self.optimizer is not None or self.lr_scheduler is not None
+ ):
+ raise RuntimeError(
+ "Passing `optimizers` is not allowed if PyTorch FSDP is enabled. "
+ "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
+ )
+ default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
+ callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
+ self.callback_handler = CallbackHandler(
+ callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
+ )
+ self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
+
+ # Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
+ self._loggers_initialized = False
+
+ # Create distant repo and output directory if needed
+ self.hub_model_id = None
+ if self.args.push_to_hub:
+ self.init_hf_repo()
+ if self.args.should_save:
+ os.makedirs(self.args.output_dir, exist_ok=True)
+
+ if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
+ raise TypeError("The `data_collator` should be a simple callable (function, class with `__call__`).")
+
+ if args.max_steps > 0 and args.num_train_epochs > 0:
+ logger.info("max_steps is given, it will override any value given in num_train_epochs")
+
+ if train_dataset is not None and not has_length(train_dataset) and args.max_steps <= 0:
+ raise ValueError(
+ "The train_dataset does not implement __len__, max_steps has to be specified. "
+ "The number of steps needs to be known in advance for the learning rate scheduler."
+ )
+
+ if (
+ train_dataset is not None
+ and isinstance(train_dataset, torch.utils.data.IterableDataset)
+ and args.group_by_length
+ ):
+ raise ValueError("the `--group_by_length` option is only available for `Dataset`, not `IterableDataset")
+
+ self._signature_columns = None
+
+ # Mixed precision setup
+ self.use_apex = False
+ self.use_cpu_amp = False
+
+ # Mixed precision setup for SageMaker Model Parallel
+ if is_sagemaker_mp_enabled():
+ # BF16 + model parallelism in SageMaker: currently not supported, raise an error
+ if args.bf16:
+ raise ValueError("SageMaker Model Parallelism does not support BF16 yet. Please use FP16 instead ")
+
+ if IS_SAGEMAKER_MP_POST_1_10:
+ # When there's mismatch between SMP config and trainer argument, use SMP config as truth
+ if args.fp16 != smp.state.cfg.fp16:
+ logger.warning(
+ f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, "
+ f"but FP16 provided in trainer argument is {args.fp16}, "
+ f"setting to {smp.state.cfg.fp16}"
+ )
+ args.fp16 = smp.state.cfg.fp16
+ else:
+ # smp < 1.10 does not support fp16 in trainer.
+ if hasattr(smp.state.cfg, "fp16"):
+ logger.warning(
+ f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, "
+ "but SageMaker Model Parallelism < 1.10 does not support FP16 in trainer."
+ )
+ if (args.fp16 or args.bf16) and args.half_precision_backend == "auto":
+ if args.device == torch.device("cpu"):
+ if args.fp16:
+ if not is_torch_greater_or_equal_than_2_3:
+ raise ValueError("Tried to use `fp16` but it is not supported on cpu")
+ else:
+ args.half_precision_backend = "cpu_amp"
+ logger.info(f"Using {args.half_precision_backend} half precision backend")
+
+ if (args.fp16 or args.bf16) and not (self.is_deepspeed_enabled or is_sagemaker_mp_enabled()):
+ # deepspeed and SageMaker Model Parallel manage their own half precision
+ if args.half_precision_backend == "cpu_amp":
+ self.use_cpu_amp = True
+ self.amp_dtype = torch.bfloat16
+ elif args.half_precision_backend == "apex":
+ self.use_apex = True
+
+ # Label smoothing
+ if self.args.label_smoothing_factor != 0:
+ self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)
+ else:
+ self.label_smoother = None
+
+ # Check for multi-label classification incompatibility
+ if self.args.label_smoothing_factor > 0:
+ if getattr(self.model.config, "problem_type", None) == "multi_label_classification":
+ warnings.warn(
+ "Label smoothing is not compatible with multi-label classification. "
+ "Disabling label smoothing for this training run.",
+ UserWarning,
+ )
+ self.label_smoother = None
+
+ self.control = TrainerControl()
+
+ self.state = TrainerState(
+ is_local_process_zero=self.is_local_process_zero(),
+ is_world_process_zero=self.is_world_process_zero(),
+ stateful_callbacks=[
+ cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
+ ],
+ )
+ # Internal variable to count flos in each process, will be accumulated in `self.state.total_flos` then
+ # returned to 0 every time flos need to be logged
+ self.current_flos = 0
+ self.hp_search_backend = None
+
+ model_to_inspect = self.model
+ if _is_peft_model(self.model):
+ if hasattr(self.model, "get_base_model"):
+ model_to_inspect = self.model.get_base_model()
+ else:
+ # PeftMixedModel do not provide a `get_base_model` method
+ model_to_inspect = self.model.base_model.model
+ default_label_names = find_labels(model_to_inspect.__class__)
+ self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
+ self.can_return_loss = can_return_loss(model_to_inspect.__class__)
+ self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)
+
+ # Internal variables to help with automatic batch size reduction
+ self._train_batch_size = args.train_batch_size
+ self._created_lr_scheduler = False
+
+ # very last
+ self._memory_tracker.stop_and_update_metrics()
+
+ self.is_fsdp_xla_v2_enabled = args.fsdp_config.get("xla_fsdp_v2", False)
+ if self.is_fsdp_xla_v2_enabled:
+ if not IS_XLA_FSDPV2_POST_2_2:
+ raise ValueError("FSDPv2 requires `torch_xla` 2.2 or higher.")
+ # Prepare the SPMD mesh that is going to be used by the data loader and the FSDPv2 wrapper.
+ # Tensor axis is just a placeholder where it will not be used in FSDPv2.
+ num_devices = xr.global_runtime_device_count()
+ xs.set_global_mesh(xs.Mesh(np.array(range(num_devices)), (num_devices, 1), axis_names=("fsdp", "tensor")))
+ self.is_fsdp_xla_v1_enabled = self.is_fsdp_xla_enabled and not self.is_fsdp_xla_v2_enabled
+
+ @property
+ def tokenizer(self) -> Optional[PreTrainedTokenizerBase]:
+ logger.warning("Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.")
+ return self.processing_class
+
+ @tokenizer.setter
+ def tokenizer(self, processing_class) -> None:
+ logger.warning(
+ "Trainer.tokenizer is now deprecated. You should use `Trainer.processing_class = processing_class` instead."
+ )
+ self.processing_class = processing_class
+
+ def _activate_neftune(self, model):
+ r"""
+ Activates the neftune as presented in this code: https://github.com/neelsjain/NEFTune and paper:
+ https://huggingface.co/papers/2310.05914
+ """
+ unwrapped_model = self.accelerator.unwrap_model(model)
+
+ if _is_peft_model(unwrapped_model):
+ embeddings = unwrapped_model.base_model.model.get_input_embeddings()
+ else:
+ embeddings = unwrapped_model.get_input_embeddings()
+
+ del unwrapped_model
+
+ embeddings.neftune_noise_alpha = self.neftune_noise_alpha
+ hook_handle = embeddings.register_forward_hook(neftune_post_forward_hook)
+ self.neftune_hook_handle = hook_handle
+ return model
+
+ def _deactivate_neftune(self, model):
+ """
+ Deactivates the neftune method. Make sure to call `_activate_neftune` first.
+ """
+ if not hasattr(self, "neftune_hook_handle"):
+ raise ValueError("Neftune is not activated make sure to call `trainer._activate_neftune()` first")
+
+ unwrapped_model = self.accelerator.unwrap_model(model)
+
+ if _is_peft_model(unwrapped_model):
+ embeddings = unwrapped_model.base_model.model.get_input_embeddings()
+ else:
+ embeddings = unwrapped_model.get_input_embeddings()
+
+ self.neftune_hook_handle.remove()
+ del embeddings.neftune_noise_alpha, unwrapped_model
+
+ def add_callback(self, callback):
+ """
+ Add a callback to the current list of [`~transformers.TrainerCallback`].
+
+ Args:
+ callback (`type` or [`~transformers.TrainerCallback]`):
+ A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
+ first case, will instantiate a member of that class.
+ """
+ self.callback_handler.add_callback(callback)
+
+ def pop_callback(self, callback):
+ """
+ Remove a callback from the current list of [`~transformers.TrainerCallback`] and returns it.
+
+ If the callback is not found, returns `None` (and no error is raised).
+
+ Args:
+ callback (`type` or [`~transformers.TrainerCallback]`):
+ A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
+ first case, will pop the first member of that class found in the list of callbacks.
+
+ Returns:
+ [`~transformers.TrainerCallback`]: The callback removed, if found.
+ """
+ return self.callback_handler.pop_callback(callback)
+
+ def remove_callback(self, callback):
+ """
+ Remove a callback from the current list of [`~transformers.TrainerCallback`].
+
+ Args:
+ callback (`type` or [`~transformers.TrainerCallback]`):
+ A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
+ first case, will remove the first member of that class found in the list of callbacks.
+ """
+ self.callback_handler.remove_callback(callback)
+
+ def _move_model_to_device(self, model, device):
+ if getattr(model, "hf_device_map", None) is not None:
+ logger.warning(
+ "The model is already on multiple devices. Skipping the move to device specified in `args`."
+ )
+ return
+ model = model.to(device)
+ # Moving a model to an XLA device disconnects the tied weights, so we have to retie them.
+ if self.args.parallel_mode == ParallelMode.TPU and hasattr(model, "tie_weights"):
+ model.tie_weights()
+
+ def _align_special_tokens(self):
+ """
+ Aligns the special tokens of the tokenizer with the model configs.
+
+ A new tokens may be defined in the tokenizer for fine-tuning purposes, e.g. an "end of turn" token may be
+ added on chat models. In that case, we want the model configs to be aligned with the tokenizer, so that all
+ downstream uses work as expected. This alignment should happen before training, to ensure the prediction step
+ uses the new tokens as well.
+ """
+ if isinstance(self.processing_class, ProcessorMixin):
+ tokenizer: PreTrainedTokenizerBase = self.processing_class.tokenizer
+ else:
+ tokenizer = self.processing_class
+ model_has_generation_config = (
+ hasattr(self.model, "generation_config") and self.model.generation_config is not None
+ )
+ updated_tokens = {}
+
+ # 1 - Align EOS token. EOS is more complex than the others, as `generation_config` may hold more than one EOS
+ # token.
+ tokenizer_has_new_eos = tokenizer.eos_token_id != self.model.config.eos_token_id
+ if model_has_generation_config:
+ # `generation_config.eos_token_id` is None: direct comparison
+ if self.model.generation_config.eos_token_id is None:
+ tokenizer_has_new_eos |= tokenizer.eos_token_id != self.model.generation_config.eos_token_id
+ else:
+ # `generation_config.eos_token_id` is an `int`: convert it to list (and continue below)
+ if isinstance(self.model.generation_config.eos_token_id, int):
+ self.model.generation_config.eos_token_id = [self.model.generation_config.eos_token_id]
+ # `generation_config.eos_token_id` is a `list`: check if the tokenizer's EOS token is in the list
+ tokenizer_has_new_eos |= tokenizer.eos_token_id not in self.model.generation_config.eos_token_id
+
+ if tokenizer_has_new_eos:
+ updated_tokens["eos_token_id"] = tokenizer.eos_token_id
+ self.model.config.eos_token_id = tokenizer.eos_token_id
+ # The generation config may hold more than one EOS token. We preserve the original EOS tokens: any of the
+ # EOS tokens defined here will halt generation.
+ if model_has_generation_config:
+ all_eos_tokens = [tokenizer.eos_token_id]
+ if self.model.generation_config.eos_token_id is not None:
+ all_eos_tokens += list(self.model.generation_config.eos_token_id)
+ self.model.generation_config.eos_token_id = [token for token in all_eos_tokens if token is not None]
+
+ # 2 - Align BOS
+ tokenizer_has_new_bos = tokenizer.bos_token_id != self.model.config.bos_token_id
+ if model_has_generation_config:
+ tokenizer_has_new_bos |= tokenizer.bos_token_id != self.model.generation_config.bos_token_id
+
+ if tokenizer_has_new_bos:
+ updated_tokens["bos_token_id"] = tokenizer.bos_token_id
+ self.model.config.bos_token_id = tokenizer.bos_token_id
+ if model_has_generation_config:
+ self.model.generation_config.bos_token_id = tokenizer.bos_token_id
+
+ # 3 - Align PAD
+ tokenizer_has_new_pad = tokenizer.pad_token_id != self.model.config.pad_token_id
+ if model_has_generation_config:
+ tokenizer_has_new_pad |= tokenizer.pad_token_id != self.model.generation_config.pad_token_id
+
+ if tokenizer_has_new_pad:
+ updated_tokens["pad_token_id"] = tokenizer.pad_token_id
+ self.model.config.pad_token_id = tokenizer.pad_token_id
+ if model_has_generation_config:
+ self.model.generation_config.pad_token_id = tokenizer.pad_token_id
+
+ # 4 - Warn users about the changes
+ if len(updated_tokens) > 0:
+ logger.warning(
+ "The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. "
+ "The model config and generation config were aligned accordingly, being updated with the tokenizer's "
+ f"values. Updated tokens: {updated_tokens}."
+ )
+
+ def _set_signature_columns_if_needed(self):
+ if self._signature_columns is None:
+ # Inspect model forward signature to keep only the arguments it accepts.
+ model_to_inspect = self.model
+ if _is_peft_model(self.model):
+ if hasattr(self.model, "get_base_model"):
+ model_to_inspect = self.model.get_base_model()
+ else:
+ # PeftMixedModel do not provide a `get_base_model` method
+ model_to_inspect = self.model.base_model.model
+ signature = inspect.signature(model_to_inspect.forward)
+ self._signature_columns = list(signature.parameters.keys())
+ # Labels may be named label or label_ids, the default data collator handles that.
+ self._signature_columns += list(set(["label", "label_ids"] + self.label_names))
+
+ def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
+ if not self.args.remove_unused_columns:
+ return dataset
+ self._set_signature_columns_if_needed()
+ signature_columns = self._signature_columns
+
+ ignored_columns = list(set(dataset.column_names) - set(signature_columns))
+ if len(ignored_columns) > 0:
+ dset_description = "" if description is None else f"in the {description} set"
+ logger.info(
+ f"The following columns {dset_description} don't have a corresponding argument in "
+ f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
+ f" If {', '.join(ignored_columns)} are not expected by `{self.model.__class__.__name__}.forward`, "
+ " you can safely ignore this message."
+ )
+
+ columns = [k for k in signature_columns if k in dataset.column_names]
+ if len(columns) == 0:
+ raise ValueError(
+ f"No columns in the dataset match the model's forward method signature: ({', '.join(signature_columns)}). "
+ f"The following columns have been ignored: [{', '.join(ignored_columns)}]. "
+ "Please check the dataset and model. You may need to set `remove_unused_columns=False` in `TrainingArguments`."
+ )
+
+ if version.parse(datasets.__version__) < version.parse("1.4.0"):
+ dataset.set_format(
+ type=dataset.format["type"], columns=columns, format_kwargs=dataset.format["format_kwargs"]
+ )
+ return dataset
+ else:
+ return dataset.remove_columns(ignored_columns)
+
+ def _get_collator_with_removed_columns(
+ self, data_collator: Callable, description: Optional[str] = None
+ ) -> Callable:
+ """Wrap the data collator in a callable removing unused columns."""
+ if not self.args.remove_unused_columns:
+ return data_collator
+ self._set_signature_columns_if_needed()
+ signature_columns = self._signature_columns
+
+ remove_columns_collator = RemoveColumnsCollator(
+ data_collator=data_collator,
+ signature_columns=signature_columns,
+ logger=logger,
+ description=description,
+ model_name=self.model.__class__.__name__,
+ )
+ return remove_columns_collator
+
+ def _get_train_sampler(self, train_dataset: Optional[Dataset] = None) -> Optional[torch.utils.data.Sampler]:
+ if train_dataset is None:
+ train_dataset = self.train_dataset
+ if train_dataset is None or not has_length(train_dataset):
+ return None
+
+ # Build the sampler.
+ if self.args.group_by_length:
+ if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
+ lengths = (
+ train_dataset[self.args.length_column_name]
+ if self.args.length_column_name in train_dataset.column_names
+ else None
+ )
+ else:
+ lengths = None
+ model_input_name = (
+ self.processing_class.model_input_names[0] if self.processing_class is not None else None
+ )
+ return LengthGroupedSampler(
+ self.args.train_batch_size * self.args.gradient_accumulation_steps,
+ dataset=train_dataset,
+ lengths=lengths,
+ model_input_name=model_input_name,
+ )
+
+ else:
+ return RandomSampler(train_dataset)
+
+ def _get_dataloader(
+ self,
+ dataset: Dataset,
+ description: str,
+ batch_size: int,
+ sampler_fn: Optional[Callable[[Dataset], torch.utils.data.Sampler]] = None,
+ is_training: bool = False,
+ dataloader_key: Optional[str] = None,
+ ) -> DataLoader:
+ """Create a [`~torch.utils.data.DataLoader`] from the given dataset."""
+
+ data_collator = self.data_collator
+ if is_datasets_available() and isinstance(dataset, datasets.Dataset):
+ dataset = self._remove_unused_columns(dataset, description=description)
+ else:
+ data_collator = self._get_collator_with_removed_columns(self.data_collator, description=description)
+
+ dataloader_params = {
+ "batch_size": batch_size,
+ "collate_fn": data_collator,
+ "num_workers": self.args.dataloader_num_workers,
+ "pin_memory": self.args.dataloader_pin_memory,
+ "persistent_workers": self.args.dataloader_persistent_workers,
+ }
+
+ if not isinstance(dataset, torch.utils.data.IterableDataset):
+ if sampler_fn is not None:
+ dataloader_params["sampler"] = sampler_fn(dataset)
+ dataloader_params["drop_last"] = self.args.dataloader_drop_last
+ dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
+ if is_training:
+ dataloader_params["worker_init_fn"] = partial(
+ seed_worker, num_workers=self.args.dataloader_num_workers, rank=self.args.process_index
+ )
+
+ dataloader = self.accelerator.prepare(DataLoader(dataset, **dataloader_params))
+
+ # Store the prepared dataloader for subsequent evaluations if using persistent workers.
+ if dataloader_key is not None and self.args.dataloader_persistent_workers:
+ if hasattr(self, "_eval_dataloaders"):
+ self._eval_dataloaders[dataloader_key] = dataloader
+ else:
+ self._eval_dataloaders = {dataloader_key: dataloader}
+
+ return dataloader
+
+ def get_train_dataloader(self) -> DataLoader:
+ """
+ Returns the training [`~torch.utils.data.DataLoader`].
+
+ Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed
+ training if necessary) otherwise.
+
+ Subclass and override this method if you want to inject some custom behavior.
+ """
+ if self.train_dataset is None:
+ raise ValueError("Trainer: training requires a train_dataset.")
+
+ return self._get_dataloader(
+ dataset=self.train_dataset,
+ description="Training",
+ batch_size=self._train_batch_size,
+ sampler_fn=self._get_train_sampler,
+ is_training=True,
+ )
+
+ def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
+ if eval_dataset is None or not has_length(eval_dataset):
+ return None
+ # Build the sampler.
+
+ # Deprecated code
+ if self.args.use_legacy_prediction_loop:
+ if is_torch_xla_available():
+ return SequentialDistributedSampler(
+ eval_dataset, num_replicas=xr.world_size(), rank=xr.global_ordinal()
+ )
+ elif is_sagemaker_mp_enabled():
+ return SequentialDistributedSampler(
+ eval_dataset,
+ num_replicas=smp.dp_size(),
+ rank=smp.dp_rank(),
+ batch_size=self.args.per_device_eval_batch_size,
+ )
+ else:
+ return SequentialSampler(eval_dataset)
+
+ if self.args.group_by_length:
+ if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
+ lengths = (
+ eval_dataset[self.args.length_column_name]
+ if self.args.length_column_name in eval_dataset.column_names
+ else None
+ )
+ else:
+ lengths = None
+ model_input_name = (
+ self.processing_class.model_input_names[0] if self.processing_class is not None else None
+ )
+ return LengthGroupedSampler(
+ self.args.eval_batch_size,
+ dataset=eval_dataset,
+ lengths=lengths,
+ model_input_name=model_input_name,
+ )
+
+ if self.args.world_size <= 1:
+ return SequentialSampler(eval_dataset)
+ else:
+ return None
+
+ def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None) -> DataLoader:
+ """
+ Returns the evaluation [`~torch.utils.data.DataLoader`].
+
+ Subclass and override this method if you want to inject some custom behavior.
+
+ Args:
+ eval_dataset (`str` or `torch.utils.data.Dataset`, *optional*):
+ If a `str`, will use `self.eval_dataset[eval_dataset]` as the evaluation dataset. If a `Dataset`, will override `self.eval_dataset` and must implement `__len__`. If it is a [`~datasets.Dataset`], columns not accepted by the `model.forward()` method are automatically removed.
+ """
+ if eval_dataset is None and self.eval_dataset is None:
+ raise ValueError("Trainer: evaluation requires an eval_dataset.")
+
+ # If we have persistent workers, don't do a fork bomb especially as eval datasets
+ # don't change during training
+ dataloader_key = eval_dataset if isinstance(eval_dataset, str) else "eval"
+ if (
+ hasattr(self, "_eval_dataloaders")
+ and dataloader_key in self._eval_dataloaders
+ and self.args.dataloader_persistent_workers
+ ):
+ return self._eval_dataloaders[dataloader_key]
+
+ eval_dataset = (
+ self.eval_dataset[eval_dataset]
+ if isinstance(eval_dataset, str)
+ else eval_dataset
+ if eval_dataset is not None
+ else self.eval_dataset
+ )
+
+ return self._get_dataloader(
+ dataset=eval_dataset,
+ description="Evaluation",
+ batch_size=self.args.eval_batch_size,
+ sampler_fn=self._get_eval_sampler,
+ dataloader_key=dataloader_key,
+ )
+
+ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
+ """
+ Returns the test [`~torch.utils.data.DataLoader`].
+
+ Subclass and override this method if you want to inject some custom behavior.
+
+ Args:
+ test_dataset (`torch.utils.data.Dataset`, *optional*):
+ The test dataset to use. If it is a [`~datasets.Dataset`], columns not accepted by the
+ `model.forward()` method are automatically removed. It must implement `__len__`.
+ """
+ return self._get_dataloader(
+ dataset=test_dataset,
+ description="test",
+ batch_size=self.args.eval_batch_size,
+ sampler_fn=self._get_eval_sampler,
+ )
+
+ def create_optimizer_and_scheduler(self, num_training_steps: int):
+ """
+ Setup the optimizer and the learning rate scheduler.
+
+ We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
+ Trainer's init through `optimizers`, or subclass and override this method (or `create_optimizer` and/or
+ `create_scheduler`) in a subclass.
+ """
+ self.create_optimizer()
+ if IS_SAGEMAKER_MP_POST_1_10 and smp.state.cfg.fp16:
+ # If smp >= 1.10 and fp16 is enabled, we unwrap the optimizer
+ optimizer = self.optimizer.optimizer
+ else:
+ optimizer = self.optimizer
+ self.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)
+
+ def get_decay_parameter_names(self, model) -> list[str]:
+ """
+ Get all parameter names that weight decay will be applied to.
+
+ This function filters out parameters in two ways:
+ 1. By layer type (instances of layers specified in ALL_LAYERNORM_LAYERS)
+ 2. By parameter name patterns (containing 'bias', or variation of 'norm')
+ """
+ forbidden_name_patterns = [r"bias", r"layernorm", r"rmsnorm", r"(?:^|\.)norm(?:$|\.)", r"_norm(?:$|\.)"]
+ decay_parameters = get_parameter_names(model, [nn.LayerNorm], forbidden_name_patterns)
+ return decay_parameters
+
+ def create_optimizer(self):
+ """
+ Setup the optimizer.
+
+ We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
+ Trainer's init through `optimizers`, or subclass and override this method in a subclass.
+ """
+ opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
+
+ if self.optimizer is None:
+ decay_parameters = self.get_decay_parameter_names(opt_model)
+ optimizer_grouped_parameters = [
+ {
+ "params": [
+ p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
+ ],
+ "weight_decay": self.args.weight_decay,
+ },
+ {
+ "params": [
+ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
+ ],
+ "weight_decay": 0.0,
+ },
+ ]
+
+ if self.optimizer_cls_and_kwargs is not None:
+ optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
+ else:
+ optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(self.args, opt_model)
+
+ # Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
+ # e.g. for GaLore optimizer.
+ if "params" in optimizer_kwargs:
+ optimizer_grouped_parameters = optimizer_kwargs.pop("params")
+
+ # Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
+ # e.g. for LOMO optimizer.
+ if "model" in optimizer_kwargs:
+ optimizer_grouped_parameters = optimizer_kwargs.pop("model")
+
+ # For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
+ # to avoid arguments conflicts.
+ if "optimizer_dict" in optimizer_kwargs:
+ optimizer_grouped_parameters = optimizer_kwargs.pop("optimizer_dict")
+
+ self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
+
+ if "bitsandbytes" in str(optimizer_cls) and optimizer_kwargs.get("optim_bits", None) == 8:
+ import bitsandbytes
+
+ manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
+
+ skipped = 0
+ for module in opt_model.modules():
+ if isinstance(module, nn.Embedding):
+ skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
+ logger.info(f"skipped {module}: {skipped / 2**20}M params")
+ manager.register_module_override(module, "weight", {"optim_bits": 32})
+ logger.debug(f"bitsandbytes: will optimize {module} in fp32")
+ logger.info(f"skipped: {skipped / 2**20}M params")
+
+ if is_sagemaker_mp_enabled():
+ self.optimizer = smp.DistributedOptimizer(self.optimizer)
+
+ return self.optimizer
+
+ def get_num_trainable_parameters(self):
+ """
+ Get the number of trainable parameters.
+ """
+ return sum(p.numel() for p in self.model.parameters() if p.requires_grad)
+
+ def get_learning_rates(self):
+ """
+ Returns the learning rate of each parameter from self.optimizer.
+ """
+ if self.optimizer is None:
+ raise ValueError("Trainer optimizer is None, please make sure you have setup the optimizer before.")
+ return [group["lr"] for group in self.optimizer.param_groups]
+
+ def get_optimizer_group(self, param: Optional[Union[str, torch.nn.parameter.Parameter]] = None):
+ """
+ Returns optimizer group for a parameter if given, else returns all optimizer groups for params.
+
+ Args:
+ param (`str` or `torch.nn.parameter.Parameter`, *optional*):
+ The parameter for which optimizer group needs to be returned.
+ """
+ if self.optimizer is None:
+ raise ValueError("Trainer optimizer is None, please make sure you have setup the optimizer before.")
+ if param is not None:
+ for group in self.optimizer.param_groups:
+ if param in group["params"]:
+ return group
+ return [group["params"] for group in self.optimizer.param_groups]
+
+ @staticmethod
+ def get_optimizer_cls_and_kwargs(
+ args: TrainingArguments, model: Optional[PreTrainedModel] = None
+ ) -> tuple[Any, Any]:
+ """
+ Returns the optimizer class and optimizer parameters based on the training arguments.
+
+ Args:
+ args (`transformers.training_args.TrainingArguments`):
+ The training arguments for the training session.
+
+ """
+
+ # parse args.optim_args
+ optim_args = {}
+ if args.optim_args:
+ for mapping in args.optim_args.replace(" ", "").split(","):
+ key, value = mapping.split("=")
+ optim_args[key] = value
+
+ optimizer_kwargs = {"lr": args.learning_rate}
+
+ adam_kwargs = {
+ "betas": (args.adam_beta1, args.adam_beta2),
+ "eps": args.adam_epsilon,
+ }
+
+ def setup_low_rank_optimizer(
+ optimizer_name: str,
+ optimizer_mapping: dict[str, Any],
+ optim_kwargs: dict[str, Any],
+ is_layerwise_supported: bool = True,
+ ) -> tuple[Any, Any]:
+ """
+ Helper function to set up low-rank optimizers like GaLore and Apollo.
+
+ Args:
+ optimizer_name (str): Name of the optimizer.
+ optimizer_mapping (dict): Mapping of optimizer names to their classes.
+ optim_kwargs (dict): Keyword arguments for the optimizer.
+ is_layerwise_supported (bool): Whether layerwise optimization is supported.
+
+ Returns:
+ tuple[Any, Any]: Optimizer class and updated optimizer kwargs.
+ """
+ is_layerwise = optimizer_name.lower().endswith("layerwise")
+ if is_layerwise and args.parallel_mode == ParallelMode.DISTRIBUTED and is_layerwise_supported:
+ raise NotImplementedError(f"Layer-wise {optimizer_name} does not support DDP at this time")
+
+ optimizer_cls = optimizer_mapping[optimizer_name]
+
+ if args.optim_target_modules is None:
+ raise ValueError(f"You need to define `optim_target_modules` to use {optimizer_name} optimizers")
+
+ if not isinstance(args.optim_target_modules, (list, str)):
+ raise TypeError(
+ f"`optim_target_modules` must be a list of strings, a regex string, or 'all-linear'. Got: {args.optim_target_modules}"
+ )
+
+ if model is None:
+ raise ValueError(f"You need to pass a model to initialize {optimizer_name} optimizer.")
+
+ all_linear = (
+ isinstance(args.optim_target_modules, str)
+ and args.optim_target_modules.replace("_", "-") == "all-linear"
+ )
+
+ target_params_names = []
+ for module_name, module in model.named_modules():
+ target_module_exists, is_regex = check_target_module_exists(
+ args.optim_target_modules, module_name, return_is_regex=True
+ )
+
+ if not isinstance(module, nn.Linear):
+ if target_module_exists and not is_regex:
+ logger.warning(
+ f"{module_name} matched but ignored. {optimizer_name} only supports linear layers."
+ )
+ continue
+
+ if not target_module_exists and not all_linear:
+ continue
+
+ target_params_names.append(module_name + ".weight")
+
+ if len(target_params_names) == 0:
+ raise ValueError(f"No target modules found for {optimizer_name} ({args.optim_target_modules}).")
+
+ target_params = [p for n, p in model.named_parameters() if n in target_params_names]
+ non_target_params = [p for n, p in model.named_parameters() if n not in target_params_names]
+ optim_kwargs.update(optim_args)
+
+ param_groups = [
+ {"params": non_target_params},
+ {"params": target_params, **optim_kwargs},
+ ]
+
+ if is_layerwise:
+ if args.gradient_accumulation_steps != 1:
+ raise ValueError(f"Layerwise {optimizer_name} does not support gradient accumulation!")
+
+ optimizer_dict = {}
+ for param in non_target_params:
+ optimizer_dict[param] = optimizer_cls([{"params": [param]}], **optimizer_kwargs)
+ for param in target_params:
+ optimizer_dict[param] = optimizer_cls([{"params": [param], **optim_kwargs}], **optimizer_kwargs)
+
+ def optimizer_hook(param):
+ if param.grad is not None:
+ optimizer_dict[param].step()
+ optimizer_dict[param].zero_grad()
+
+ for param in model.parameters():
+ if param.requires_grad:
+ param.register_post_accumulate_grad_hook(optimizer_hook)
+
+ optimizer_cls = LayerWiseDummyOptimizer
+ optimizer_kwargs.update({"optimizer_dict": optimizer_dict})
+
+ optimizer_kwargs.update({"params": param_groups})
+ return optimizer_cls, optimizer_kwargs
+
+ if args.optim == OptimizerNames.ADAFACTOR:
+ optimizer_cls = Adafactor
+ optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
+ elif args.optim in [OptimizerNames.ADAMW_TORCH, OptimizerNames.ADAMW_TORCH_FUSED]:
+ from torch.optim import AdamW
+
+ optimizer_cls = AdamW
+ optimizer_kwargs.update(adam_kwargs)
+ if args.optim == OptimizerNames.ADAMW_TORCH_FUSED:
+ optimizer_kwargs.update({"fused": True})
+ elif args.optim == OptimizerNames.ADAMW_TORCH_XLA:
+ try:
+ from torch_xla.amp.syncfree import AdamW
+
+ optimizer_cls = AdamW
+ optimizer_kwargs.update(adam_kwargs)
+ except ImportError:
+ raise ValueError("Trainer failed to import syncfree AdamW from torch_xla.")
+ elif args.optim == OptimizerNames.ADAMW_TORCH_NPU_FUSED:
+ try:
+ from torch_npu.optim import NpuFusedAdamW
+
+ optimizer_cls = NpuFusedAdamW
+ optimizer_kwargs.update(adam_kwargs)
+ except ImportError:
+ raise ValueError("Trainer failed to import FusedAdamW from torch_npu.")
+ elif args.optim == OptimizerNames.ADAMW_APEX_FUSED:
+ try:
+ from apex.optimizers import FusedAdam
+
+ optimizer_cls = FusedAdam
+ optimizer_kwargs.update(adam_kwargs)
+ except ImportError:
+ raise ValueError("Trainer tried to instantiate apex FusedAdam but apex is not installed!")
+ elif args.optim in [
+ OptimizerNames.ADAMW_BNB,
+ OptimizerNames.ADAMW_8BIT,
+ OptimizerNames.PAGED_ADAMW,
+ OptimizerNames.PAGED_ADAMW_8BIT,
+ OptimizerNames.ADEMAMIX,
+ OptimizerNames.ADEMAMIX_8BIT,
+ OptimizerNames.PAGED_ADEMAMIX,
+ OptimizerNames.PAGED_ADEMAMIX_8BIT,
+ OptimizerNames.LION,
+ OptimizerNames.LION_8BIT,
+ OptimizerNames.PAGED_LION,
+ OptimizerNames.PAGED_LION_8BIT,
+ OptimizerNames.RMSPROP_BNB,
+ OptimizerNames.RMSPROP_8BIT,
+ OptimizerNames.RMSPROP_32BIT,
+ ]:
+ try:
+ from bitsandbytes.optim import AdamW, Lion, RMSprop
+
+ is_paged = False
+ optim_bits = 32
+ optimizer_cls = None
+ additional_optim_kwargs = adam_kwargs
+ if "paged" in args.optim:
+ is_paged = True
+ if "8bit" in args.optim:
+ optim_bits = 8
+ if "adam" in args.optim:
+ optimizer_cls = AdamW
+ elif "lion" in args.optim:
+ optimizer_cls = Lion
+ additional_optim_kwargs = {"betas": (args.adam_beta1, args.adam_beta2)}
+ elif "rmsprop" in args.optim:
+ optimizer_cls = RMSprop
+ # Above we pass all `adam_kwargs` to the optimizer, here
+ # we only pass `optim_args` which can be passed by the user.
+ additional_optim_kwargs = optim_args
+ elif "ademamix" in args.optim:
+ if is_bitsandbytes_available() and version.parse(
+ importlib.metadata.version("bitsandbytes")
+ ) < version.parse("0.44.0"):
+ raise ValueError(
+ "The AdEMAMix optimizer is not supported by your current version of `bitsandbytes`. "
+ "Please install `bitsandbytes` >= 0.44.0."
+ )
+
+ from bitsandbytes.optim import AdEMAMix
+
+ optimizer_cls = AdEMAMix
+ additional_optim_kwargs = {
+ "betas": (
+ float(optim_args.get("beta1", args.adam_beta1)),
+ float(optim_args.get("beta2", args.adam_beta2)),
+ float(optim_args.get("beta3", 0.9999)),
+ ),
+ "alpha": float(optim_args.get("alpha", 5.0)),
+ "eps": float(optim_args.get("eps", args.adam_epsilon)),
+ }
+
+ if "t_alpha" in optim_args:
+ additional_optim_kwargs["t_alpha"] = int(optim_args["t_alpha"])
+
+ if "t_beta3" in optim_args:
+ additional_optim_kwargs["t_beta3"] = int(optim_args["t_beta3"])
+
+ bnb_kwargs = {"optim_bits": optim_bits}
+ if "rmsprop" not in args.optim:
+ bnb_kwargs["is_paged"] = is_paged
+
+ optimizer_kwargs.update(additional_optim_kwargs)
+ optimizer_kwargs.update(bnb_kwargs)
+ except ImportError:
+ raise ValueError("Trainer tried to instantiate bnb optimizer but `bitsandbytes` is not installed!")
+ if is_bitsandbytes_available() and version.parse(
+ importlib.metadata.version("bitsandbytes")
+ ) < version.parse("0.41.1"):
+ logger.warning(
+ "You are using 8-bit optimizers with a version of `bitsandbytes` < 0.41.1. "
+ "It is recommended to update your version as a major bug has been fixed in 8-bit optimizers."
+ )
+ elif args.optim == OptimizerNames.ADAMW_ANYPRECISION:
+ try:
+ from torchdistx.optimizers import AnyPrecisionAdamW
+
+ optimizer_cls = AnyPrecisionAdamW
+ optimizer_kwargs.update(adam_kwargs)
+
+ # TODO Change dtypes back to M=FP32, Var = BF16, Kahan = False once they can be cast together in torchdistx.
+ optimizer_kwargs.update(
+ {
+ "use_kahan_summation": strtobool(optim_args.get("use_kahan_summation", "False")),
+ "momentum_dtype": getattr(torch, optim_args.get("momentum_dtype", "float32")),
+ "variance_dtype": getattr(torch, optim_args.get("variance_dtype", "float32")),
+ "compensation_buffer_dtype": getattr(
+ torch, optim_args.get("compensation_buffer_dtype", "bfloat16")
+ ),
+ }
+ )
+ except ImportError:
+ raise ValueError("Please install https://github.com/pytorch/torchdistx")
+ elif args.optim == OptimizerNames.SGD:
+ optimizer_cls = torch.optim.SGD
+ elif args.optim == OptimizerNames.ADAGRAD:
+ optimizer_cls = torch.optim.Adagrad
+ elif args.optim == OptimizerNames.RMSPROP:
+ optimizer_cls = torch.optim.RMSprop
+ elif args.optim in [
+ OptimizerNames.GALORE_ADAMW,
+ OptimizerNames.GALORE_ADAMW_8BIT,
+ OptimizerNames.GALORE_ADAFACTOR,
+ OptimizerNames.GALORE_ADAMW_LAYERWISE,
+ OptimizerNames.GALORE_ADAMW_8BIT_LAYERWISE,
+ OptimizerNames.GALORE_ADAFACTOR_LAYERWISE,
+ ]:
+ if not is_galore_torch_available():
+ raise ImportError(
+ "You need to install `galore_torch` in order to use GaLore optimizers"
+ " install it with `pip install git+https://github.com/jiaweizzhao/GaLore`"
+ )
+ from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit
+
+ optimizer_mapping = {
+ OptimizerNames.GALORE_ADAMW: GaLoreAdamW,
+ OptimizerNames.GALORE_ADAMW_8BIT: GaLoreAdamW8bit,
+ OptimizerNames.GALORE_ADAFACTOR: GaLoreAdafactor,
+ OptimizerNames.GALORE_ADAMW_LAYERWISE: GaLoreAdamW,
+ OptimizerNames.GALORE_ADAMW_8BIT_LAYERWISE: GaLoreAdamW8bit,
+ OptimizerNames.GALORE_ADAFACTOR_LAYERWISE: GaLoreAdafactor,
+ }
+
+ galore_optim_kwargs = {
+ "rank": int(optim_args.pop("rank", 128)),
+ "update_proj_gap": int(optim_args.pop("update_proj_gap", 200)),
+ "scale": float(optim_args.pop("scale", 0.25)),
+ "proj_type": optim_args.pop("proj_type", "std"),
+ }
+
+ optimizer_cls, optimizer_kwargs = setup_low_rank_optimizer(
+ args.optim, optimizer_mapping, galore_optim_kwargs
+ )
+ if args.optim == OptimizerNames.GALORE_ADAFACTOR:
+ optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
+ elif args.optim in [
+ OptimizerNames.APOLLO_ADAMW,
+ OptimizerNames.APOLLO_ADAMW_LAYERWISE,
+ ]:
+ if not is_apollo_torch_available():
+ raise ImportError(
+ "You need to install `apollo_torch` in order to use APOLLO optimizers"
+ " install it with `pip install git+https://github.com/zhuhanqing/APOLLO`"
+ )
+ from apollo_torch import APOLLOAdamW
+
+ optimizer_mapping = {
+ OptimizerNames.APOLLO_ADAMW: APOLLOAdamW,
+ OptimizerNames.APOLLO_ADAMW_LAYERWISE: APOLLOAdamW,
+ }
+
+ apollo_optim_kwargs = {
+ "rank": int(optim_args.pop("rank", 128)),
+ "proj": optim_args.pop("proj", "random"),
+ "scale_type": optim_args.pop("scale_type", "channel"),
+ "update_proj_gap": int(optim_args.pop("update_proj_gap", 200)),
+ "scale": float(optim_args.pop("scale", 1.0)),
+ "proj_type": optim_args.pop("proj_type", "std"),
+ }
+ apollo_optim_kwargs.update(adam_kwargs)
+
+ optimizer_cls, optimizer_kwargs = setup_low_rank_optimizer(
+ args.optim, optimizer_mapping, apollo_optim_kwargs
+ )
+ elif args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
+ if not is_lomo_available():
+ raise ImportError(
+ "You need to install `lomo_optim` in order to use LOMO optimizers"
+ " install it with `pip install lomo-optim`"
+ )
+ if not is_accelerate_available("0.30.0"):
+ raise ImportError("You need to have `accelerate>=0.30.0` to be able to use LOMO optimizers")
+
+ if model is None:
+ raise ValueError("You need to pass a `model` in order to correctly initialize a LOMO optimizer.")
+
+ from lomo_optim import AdaLomo, Lomo
+
+ if "ada" in args.optim:
+ optimizer_cls = AdaLomo
+ else:
+ optimizer_cls = Lomo
+
+ optimizer_kwargs.update({"model": model})
+ elif args.optim == OptimizerNames.GROKADAMW:
+ if not is_grokadamw_available():
+ raise ValueError("Please install grokadamw with `pip install grokadamw`")
+
+ from grokadamw import GrokAdamW
+
+ optimizer_cls = GrokAdamW
+ optimizer_kwargs.update(
+ {
+ "alpha_init": float(optim_args.get("alpha_init", 0.98)),
+ "lamb": float(optim_args.get("lamb", 2.0)),
+ "gamma": float(optim_args.get("gamma", 0.1)),
+ "grokking_signal_decay_rate": float(optim_args.get("grokking_signal_decay_rate", 0.1)),
+ "gradient_clipping": float(optim_args.get("gradient_clipping", 1.0)),
+ }
+ )
+ elif args.optim in [
+ OptimizerNames.ADAMW_TORCH_4BIT,
+ OptimizerNames.ADAMW_TORCH_8BIT,
+ ]:
+ if not is_torchao_available() or version.parse(importlib.metadata.version("torchao")) < version.parse(
+ "0.4.0"
+ ):
+ raise ImportError(
+ "You need to have `torchao>=0.4.0` in order to use torch 4-bit optimizers."
+ "Install it with `pip install torchao` or follow the instructions here: https://github.com/pytorch/ao"
+ )
+ if version.parse(importlib.metadata.version("torch")) <= version.parse("2.4"):
+ raise ImportError(
+ "You need to have `torch>2.4` in order to use torch 4-bit optimizers. "
+ "Install it with `pip install --upgrade torch` it is available on pipy. Otherwise, you need to install torch nightly."
+ )
+ if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.11.0"):
+ # https://github.com/pytorch/ao/pull/2159
+ from torchao.optim import AdamW4bit, AdamW8bit
+ else:
+ from torchao.prototype.low_bit_optim import AdamW4bit, AdamW8bit
+ if args.optim == OptimizerNames.ADAMW_TORCH_4BIT:
+ optimizer_cls = AdamW4bit
+ elif args.optim == OptimizerNames.ADAMW_TORCH_8BIT:
+ optimizer_cls = AdamW8bit
+ else:
+ raise ValueError("Invalid optimizer")
+ optimizer_kwargs.update(adam_kwargs)
+ elif args.optim in [
+ OptimizerNames.SCHEDULE_FREE_RADAM,
+ OptimizerNames.SCHEDULE_FREE_ADAMW,
+ OptimizerNames.SCHEDULE_FREE_SGD,
+ ]:
+ if not is_schedulefree_available():
+ raise ImportError(
+ "You need to install `schedulefree` in order to use schedulefree optimizers. "
+ "Install it with `pip install schedulefree.`"
+ )
+ if not is_accelerate_available("0.30.0"):
+ raise ImportError("You need to have `accelerate>=0.30.0` to be able to use schedulefree optimizers")
+ from schedulefree import AdamWScheduleFree, SGDScheduleFree
+
+ additional_optim_kwargs = {}
+ require_warmup = True
+
+ if args.optim == OptimizerNames.SCHEDULE_FREE_RADAM:
+ if not is_schedulefree_available("1.4.0"):
+ raise ImportError(
+ "You need to install `schedulefree>=1.4.0` in order to use RAdamScheduleFree optimizer. "
+ "Install it with `pip install schedulefree.`"
+ )
+ from schedulefree import RAdamScheduleFree
+
+ optimizer_cls = RAdamScheduleFree
+ additional_optim_kwargs = adam_kwargs
+ require_warmup = False
+ elif args.optim == OptimizerNames.SCHEDULE_FREE_ADAMW:
+ optimizer_cls = AdamWScheduleFree
+ additional_optim_kwargs = adam_kwargs
+ elif args.optim == OptimizerNames.SCHEDULE_FREE_SGD:
+ optimizer_cls = SGDScheduleFree
+ else:
+ raise ValueError("Invalid schedulefree optimizer")
+
+ additional_optim_kwargs["weight_decay"] = args.weight_decay
+ if require_warmup:
+ additional_optim_kwargs["warmup_steps"] = args.warmup_steps
+ additional_optim_kwargs.update(
+ {
+ "weight_lr_power": float(optim_args.get("weight_lr_power", 2.0)),
+ "r": float(optim_args.get("r", 0.0)),
+ }
+ )
+ optimizer_kwargs.update(additional_optim_kwargs)
+ elif args.optim == OptimizerNames.STABLE_ADAMW:
+ if not is_torch_optimi_available():
+ raise ImportError(
+ "You need to install `torch-optimi` in order to use stable_adamw optimizers. "
+ "Install it with `pip install torch-optimi`."
+ )
+ from optimi import StableAdamW
+
+ max_lr = optim_args.pop("max_lr", None)
+ if max_lr is not None:
+ max_lr = float(max_lr)
+
+ kahan_sum = optim_args.pop("kahan_sum", None)
+ if kahan_sum is not None:
+ kahan_sum = bool(kahan_sum)
+
+ adam_kwargs["weight_decay"] = args.weight_decay
+ stable_adamw_kwargs = {
+ "decouple_lr": bool(optim_args.pop("decouple_lr", False)),
+ "max_lr": max_lr,
+ "kahan_sum": kahan_sum,
+ }
+
+ optimizer_cls = StableAdamW
+ optimizer_kwargs.update(adam_kwargs)
+ optimizer_kwargs.update(stable_adamw_kwargs)
+ else:
+ raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
+ return optimizer_cls, optimizer_kwargs
+
+ def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
+ """
+ Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
+ passed as an argument.
+
+ Args:
+ num_training_steps (int): The number of training steps to do.
+ """
+ if self.lr_scheduler is None:
+ self.lr_scheduler = get_scheduler(
+ self.args.lr_scheduler_type,
+ optimizer=self.optimizer if optimizer is None else optimizer,
+ num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
+ num_training_steps=num_training_steps,
+ scheduler_specific_kwargs=self.args.lr_scheduler_kwargs,
+ )
+ self._created_lr_scheduler = True
+ return self.lr_scheduler
+
+ def num_examples(self, dataloader: DataLoader) -> int:
+ """
+ Helper to get number of samples in a [`~torch.utils.data.DataLoader`] by accessing its dataset. When
+ dataloader.dataset does not exist or has no length, estimates as best it can
+ """
+ try:
+ dataset = dataloader.dataset
+ # Special case for IterableDatasetShard, we need to dig deeper
+ if isinstance(dataset, IterableDatasetShard):
+ return len(dataloader.dataset.dataset)
+ return len(dataloader.dataset)
+ except (NameError, AttributeError, TypeError): # no dataset or length, estimate by length of dataloader
+ return len(dataloader) * self.args.per_device_train_batch_size
+
+ @staticmethod
+ def num_tokens(train_dl: DataLoader, max_steps: Optional[int] = None) -> int:
+ """
+ Helper to get number of tokens in a [`~torch.utils.data.DataLoader`] by enumerating dataloader.
+ """
+ train_tokens = 0
+ try:
+ for batch in train_dl:
+ tokens = batch["input_ids"].numel()
+ if max_steps is not None:
+ return tokens * max_steps
+ train_tokens += tokens
+ except KeyError:
+ logger.warning("Cannot get num_tokens from dataloader")
+ return train_tokens
+
+ def _hp_search_setup(self, trial: Union["optuna.Trial", dict[str, Any]]):
+ """HP search setup code"""
+ self._trial = trial
+
+ if self.hp_search_backend is None or trial is None:
+ return
+ if self.hp_search_backend == HPSearchBackend.OPTUNA:
+ params = self.hp_space(trial)
+ elif self.hp_search_backend == HPSearchBackend.RAY:
+ params = trial
+ params.pop("wandb", None)
+ elif self.hp_search_backend == HPSearchBackend.SIGOPT:
+ params = {k: int(v) if isinstance(v, str) else v for k, v in trial.assignments.items()}
+ elif self.hp_search_backend == HPSearchBackend.WANDB:
+ params = trial
+
+ for key, value in params.items():
+ if not hasattr(self.args, key):
+ logger.warning(
+ f"Trying to set {key} in the hyperparameter search but there is no corresponding field in"
+ " `TrainingArguments`."
+ )
+ continue
+ old_attr = getattr(self.args, key, None)
+ # Casting value to the proper type
+ if old_attr is not None:
+ value = type(old_attr)(value)
+
+ setattr(self.args, key, value)
+ if self.hp_search_backend == HPSearchBackend.OPTUNA:
+ logger.info(f"Trial: {trial.params}")
+ if self.hp_search_backend == HPSearchBackend.SIGOPT:
+ logger.info(f"SigOpt Assignments: {trial.assignments}")
+ if self.hp_search_backend == HPSearchBackend.WANDB:
+ logger.info(f"W&B Sweep parameters: {trial}")
+ if self.is_deepspeed_enabled:
+ if self.args.deepspeed is None:
+ raise ValueError("For sweeps with deepspeed, `args.deepspeed` must be set")
+
+ self.accelerator.free_memory()
+
+ # Rebuild the deepspeed config to reflect the updated training parameters
+ from accelerate.utils import DeepSpeedPlugin
+
+ from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig
+
+ self.args.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.args.deepspeed)
+ self.args.hf_deepspeed_config.trainer_config_process(self.args)
+ self.args.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.args.hf_deepspeed_config)
+
+ # From 1.0 on, we need to fully wipe the DS plugin when doing sweeps.
+ # Simply calling `_reset_state` is enough and doesn't need a version pin.
+ AcceleratorState()._reset_state()
+
+ self.create_accelerator_and_postprocess()
+
+ def _report_to_hp_search(self, trial: Union["optuna.Trial", dict[str, Any]], step: int, metrics: dict[str, float]):
+ if self.hp_search_backend is None or trial is None:
+ return
+ metrics = metrics.copy()
+ self.objective = self.compute_objective(metrics)
+ if self.hp_search_backend == HPSearchBackend.OPTUNA:
+ import optuna
+
+ if hasattr(trial, "study") and not trial.study._is_multi_objective():
+ trial.report(self.objective, step)
+ if trial.should_prune():
+ self.callback_handler.on_train_end(self.args, self.state, self.control)
+ raise optuna.TrialPruned()
+ elif self.hp_search_backend == HPSearchBackend.RAY:
+ import ray.train
+
+ with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
+ checkpoint = None
+ if self.control.should_save:
+ self._tune_save_checkpoint(checkpoint_dir=temp_checkpoint_dir)
+ checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)
+ metrics["objective"] = self.objective
+ ray.train.report(metrics, checkpoint=checkpoint)
+
+ def _tune_save_checkpoint(self, checkpoint_dir: str):
+ output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
+ self.save_model(output_dir, _internal_call=True)
+ if self.args.should_save:
+ # Update the `TrainerControl` state to where we are currently
+ self.state.stateful_callbacks["TrainerControl"] = self.control.state()
+ self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
+ torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
+ torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
+
+ def call_model_init(self, trial=None):
+ model_init_argcount = number_of_arguments(self.model_init)
+ if model_init_argcount == 0:
+ model = self.model_init()
+ elif model_init_argcount == 1:
+ model = self.model_init(trial)
+ else:
+ raise RuntimeError("model_init should have 0 or 1 argument.")
+
+ if model is None:
+ raise RuntimeError("model_init should not return None.")
+
+ return model
+
+ def torch_jit_model_eval(self, model, dataloader, training=False):
+ if not training:
+ if dataloader is None:
+ logger.warning("failed to use PyTorch jit mode due to current dataloader is none.")
+ return model
+ example_batch = next(iter(dataloader))
+ example_batch = self._prepare_inputs(example_batch)
+ try:
+ jit_model = copy.copy(model)
+ jit_model.eval()
+ original_forward = jit_model.__dict__.pop("_original_forward", None)
+ # remove mixed precision hooks from the model
+ if original_forward:
+ jit_model.forward = original_forward
+ autocast_handler = AutocastKwargs(cache_enabled=False)
+ with self.accelerator.autocast(autocast_handler=autocast_handler), torch.no_grad():
+ if isinstance(example_batch, dict):
+ jit_model = torch.jit.trace(jit_model, example_kwarg_inputs=example_batch, strict=False)
+ else:
+ jit_model = torch.jit.trace(
+ jit_model,
+ example_kwarg_inputs={key: example_batch[key] for key in example_batch},
+ strict=False,
+ )
+ jit_model = torch.jit.freeze(jit_model)
+ with torch.no_grad():
+ jit_model(**example_batch)
+ jit_model(**example_batch)
+ model = jit_model
+ self.use_cpu_amp = False
+ except (RuntimeError, TypeError, ValueError, NameError, IndexError) as e:
+ logger.warning(f"failed to use PyTorch jit mode due to: {e}.")
+
+ return model
+
+ def compare_trainer_and_checkpoint_args(self, training_args, trainer_state):
+ attributes_map = {
+ "logging_steps": "logging_steps",
+ "eval_steps": "eval_steps",
+ "save_steps": "save_steps",
+ }
+
+ has_warning = False
+ warning_str = "Warning: The following arguments do not match the ones in the `trainer_state.json` within the checkpoint directory: "
+ for arg_attr, state_attr in attributes_map.items():
+ arg_value = getattr(training_args, arg_attr, None)
+ state_value = getattr(trainer_state, state_attr, None)
+
+ if arg_value is not None and state_value is not None and arg_value != state_value:
+ warning_str += f"\n\t{arg_attr}: {arg_value} (from args) != {state_value} (from trainer_state.json)"
+ has_warning = True
+
+ # train bs is special as we need to account for multi-GPU
+ train_bs_args = training_args.per_device_train_batch_size
+ train_bs_state = trainer_state.train_batch_size // max(1, training_args.n_gpu)
+
+ if train_bs_args != train_bs_state:
+ warning_str += f"\n\tper_device_train_batch_size: {train_bs_args} (from args) != {train_bs_state} (from trainer_state.json)"
+ has_warning = True
+
+ if has_warning:
+ logger.warning_once(warning_str)
+
+ def _wrap_model(self, model, training=True, dataloader=None):
+ if is_sagemaker_mp_enabled():
+ # Wrapping the base model twice in a DistributedModel will raise an error.
+ if isinstance(self.model_wrapped, smp.model.DistributedModel):
+ return self.model_wrapped
+ return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)
+
+ # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again
+ if self.accelerator.unwrap_model(model, keep_torch_compile=False) is not model:
+ return model
+
+ # Mixed precision training with apex
+ if self.use_apex and training:
+ from apex import amp
+
+ model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)
+
+ # Multi-gpu training (should be after apex fp16 initialization) / 8bit models does not support DDP
+ if self.args.n_gpu > 1 and not getattr(model, "is_loaded_in_8bit", False):
+ model = nn.DataParallel(model)
+
+ if self.args.jit_mode_eval:
+ start_time = time.time()
+ model = self.torch_jit_model_eval(model, dataloader, training)
+ self.jit_compilation_time = round(time.time() - start_time, 4)
+
+ # Note: in torch.distributed mode, there's no point in wrapping the model
+ # inside a DistributedDataParallel as we'll be under `no_grad` anyways.
+ if not training:
+ return model
+
+ # Distributed training (should be after apex fp16 initialization)
+ # Distributed training using PyTorch FSDP
+ if self.is_fsdp_xla_enabled:
+ try:
+ from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
+ from torch_xla.distributed.fsdp import checkpoint_module
+ from torch_xla.distributed.fsdp.wrap import (
+ size_based_auto_wrap_policy,
+ transformer_auto_wrap_policy,
+ )
+
+ if self.is_fsdp_xla_v2_enabled:
+ from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
+ SpmdFullyShardedDataParallel as FSDPv2,
+ )
+ except ImportError:
+ raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.")
+ auto_wrap_policy = None
+ auto_wrapper_callable = None
+ default_transformer_cls_names_to_wrap = getattr(model, "_no_split_modules", None)
+ fsdp_transformer_layer_cls_to_wrap = self.args.fsdp_config.get(
+ "transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap
+ )
+
+ if self.args.fsdp_config["min_num_params"] > 0:
+ auto_wrap_policy = functools.partial(
+ size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["min_num_params"]
+ )
+ elif fsdp_transformer_layer_cls_to_wrap is not None:
+ transformer_cls_to_wrap = set()
+ for layer_class in fsdp_transformer_layer_cls_to_wrap:
+ transformer_cls = get_module_class_from_name(model, layer_class)
+ if transformer_cls is None:
+ raise Exception("Could not find the transformer layer class to wrap in the model.")
+ else:
+ transformer_cls_to_wrap.add(transformer_cls)
+
+ auto_wrap_policy = functools.partial(
+ transformer_auto_wrap_policy,
+ # Transformer layer class to wrap
+ transformer_layer_cls=transformer_cls_to_wrap,
+ )
+ fsdp_kwargs = self.args.xla_fsdp_config
+ if self.args.fsdp_config["xla_fsdp_grad_ckpt"]:
+ if model.config.use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+ )
+ model.config.use_cache = False
+
+ # Apply gradient checkpointing to auto-wrapped sub-modules if specified
+ def auto_wrapper_callable(m, *args, **kwargs):
+ target_cls = FSDP if not self.is_fsdp_xla_v2_enabled else FSDPv2
+ return target_cls(checkpoint_module(m), *args, **kwargs)
+
+ # Wrap the base model with an outer FSDP wrapper
+ if self.is_fsdp_xla_v2_enabled:
+
+ def shard_output(output, mesh):
+ from .modeling_outputs import CausalLMOutputWithPast
+
+ real_output = None
+ if isinstance(output, torch.Tensor):
+ real_output = output
+ elif isinstance(output, tuple):
+ real_output = output[0]
+ elif isinstance(output, CausalLMOutputWithPast):
+ real_output = output.logits
+
+ if real_output is None:
+ raise ValueError("Something went wrong, the output of the model shouldn't be `None`")
+ xs.mark_sharding(real_output, mesh, ("fsdp", None, None))
+
+ self.model = model = FSDPv2(
+ model,
+ shard_output=shard_output,
+ auto_wrap_policy=auto_wrap_policy,
+ auto_wrapper_callable=auto_wrapper_callable,
+ )
+ else:
+ self.model = model = FSDP(
+ model,
+ auto_wrap_policy=auto_wrap_policy,
+ auto_wrapper_callable=auto_wrapper_callable,
+ **fsdp_kwargs,
+ )
+
+ # Patch `xm.optimizer_step` should not reduce gradients in this case,
+ # as FSDP does not need gradient reduction over sharded parameters.
+ def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}):
+ loss = optimizer.step(**optimizer_args)
+ if barrier:
+ xm.mark_step()
+ return loss
+
+ xm.optimizer_step = patched_optimizer_step
+ elif is_sagemaker_dp_enabled():
+ model = nn.parallel.DistributedDataParallel(
+ model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]
+ )
+ elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+ if is_torch_neuroncore_available():
+ return model
+ kwargs = {}
+ if self.args.ddp_find_unused_parameters is not None:
+ kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters
+ elif isinstance(model, PreTrainedModel):
+ # find_unused_parameters breaks checkpointing as per
+ # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
+ kwargs["find_unused_parameters"] = not model.is_gradient_checkpointing
+ else:
+ kwargs["find_unused_parameters"] = True
+
+ if self.args.ddp_bucket_cap_mb is not None:
+ kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb
+
+ if self.args.ddp_broadcast_buffers is not None:
+ kwargs["broadcast_buffers"] = self.args.ddp_broadcast_buffers
+
+ self.accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs)
+
+ return model
+
+ def train(
+ self,
+ resume_from_checkpoint: Optional[Union[str, bool]] = None,
+ trial: Union["optuna.Trial", dict[str, Any], None] = None,
+ ignore_keys_for_eval: Optional[list[str]] = None,
+ **kwargs: Any,
+ ):
+ """
+ Main training entry point.
+
+ Args:
+ resume_from_checkpoint (`str` or `bool`, *optional*):
+ If a `str`, local path to a saved checkpoint as saved by a previous instance of [`Trainer`]. If a
+ `bool` and equals `True`, load the last checkpoint in *args.output_dir* as saved by a previous instance
+ of [`Trainer`]. If present, training will resume from the model/optimizer/scheduler states loaded here.
+ trial (`optuna.Trial` or `dict[str, Any]`, *optional*):
+ The trial run or the hyperparameter dictionary for hyperparameter search.
+ ignore_keys_for_eval (`list[str]`, *optional*)
+ A list of keys in the output of your model (if it is a dictionary) that should be ignored when
+ gathering predictions for evaluation during the training.
+ kwargs (`dict[str, Any]`, *optional*):
+ Additional keyword arguments used to hide deprecated arguments
+ """
+ if resume_from_checkpoint is False:
+ resume_from_checkpoint = None
+
+ # memory metrics - must set up as early as possible
+ self._memory_tracker.start()
+
+ args = self.args
+
+ self.is_in_train = True
+
+ # If the model uses a tokenizer, it may have a new tokens for fine-tuning purposes.
+ if isinstance(self.processing_class, (PreTrainedTokenizerBase, ProcessorMixin)) and hasattr(
+ self.model, "config"
+ ):
+ self._align_special_tokens()
+
+ # Attach NEFTune hooks if necessary
+ if self.neftune_noise_alpha is not None:
+ self.model = self._activate_neftune(self.model)
+
+ # do_train is not a reliable argument, as it might not be set and .train() still called, so
+ # the following is a workaround:
+ if (
+ (args.fp16_full_eval or args.bf16_full_eval)
+ and not args.do_train
+ and not self.is_model_parallel
+ and self.model_init is None
+ ):
+ self._move_model_to_device(self.model, args.device)
+
+ if "model_path" in kwargs:
+ resume_from_checkpoint = kwargs.pop("model_path")
+ warnings.warn(
+ "`model_path` is deprecated and will be removed in a future version. Use `resume_from_checkpoint` "
+ "instead.",
+ FutureWarning,
+ )
+ if len(kwargs) > 0:
+ raise TypeError(f"train() got unexpected keyword arguments: {', '.join(list(kwargs.keys()))}.")
+ # This might change the seed so needs to run first.
+ self._hp_search_setup(trial)
+ self._train_batch_size = self.args.train_batch_size
+
+ # Model re-init
+ model_reloaded = False
+ if self.model_init is not None:
+ # Seed must be set before instantiating the model when using model_init.
+ enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
+ self.model = self.call_model_init(trial)
+ model_reloaded = True
+ # Reinitializes optimizer and scheduler
+ self.optimizer, self.lr_scheduler = None, None
+
+ # Load potential model checkpoint
+ if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
+ resume_from_checkpoint = get_last_checkpoint(args.output_dir)
+ if resume_from_checkpoint is None:
+ raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")
+
+ if resume_from_checkpoint is not None:
+ if not is_sagemaker_mp_enabled() and not self.is_deepspeed_enabled and not self.is_fsdp_enabled:
+ self._load_from_checkpoint(resume_from_checkpoint)
+ # In case of repeating the find_executable_batch_size, set `self._train_batch_size` properly
+ state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
+ if state.train_batch_size is not None:
+ self._train_batch_size = state.train_batch_size
+
+ # If model was re-initialized, put it on the right device and update self.model_wrapped
+ if model_reloaded:
+ if self.place_model_on_device:
+ self._move_model_to_device(self.model, args.device)
+ self.model_wrapped = self.model
+
+ inner_training_loop = find_executable_batch_size(
+ self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size
+ )
+ if args.push_to_hub:
+ try:
+ # Disable progress bars when uploading models during checkpoints to avoid polluting stdout
+ hf_hub_utils.disable_progress_bars()
+ return inner_training_loop(
+ args=args,
+ resume_from_checkpoint=resume_from_checkpoint,
+ trial=trial,
+ ignore_keys_for_eval=ignore_keys_for_eval,
+ )
+ finally:
+ hf_hub_utils.enable_progress_bars()
+ else:
+ return inner_training_loop(
+ args=args,
+ resume_from_checkpoint=resume_from_checkpoint,
+ trial=trial,
+ ignore_keys_for_eval=ignore_keys_for_eval,
+ )
+
+ def get_tp_size(self) -> int:
+ """Get the tensor parallel size from either the model or DeepSpeed config."""
+
+ # 1. Check model.tp_size first
+ if (model_tp := getattr(self.model, "_tp_size", None)) is not None:
+ return model_tp
+
+ # 2. Fall back to DeepSpeed config if enabled
+ if self.is_deepspeed_enabled and (deepspeed_config := getattr(self.args, "hf_deepspeed_config", None)):
+ return deepspeed_config.config.get("tensor_parallel", {}).get("autotp_size", 1)
+
+ # 3. Default fallback
+ return 1
+
+ def get_total_train_batch_size(self, args) -> int:
+ """Calculates total batch size (micro_batch * grad_accum * dp_world_size).
+
+ Note: Only considers DP and TP (dp_world_size = world_size // tp_size)."""
+ dp_world_size = args.world_size // self.get_tp_size()
+ return self._train_batch_size * args.gradient_accumulation_steps * dp_world_size
+
+ def _inner_training_loop(
+ self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
+ ):
+ self.accelerator.free_memory()
+ self._train_batch_size = batch_size
+ if self.args.auto_find_batch_size:
+ if self.state.train_batch_size != self._train_batch_size:
+ from accelerate.utils import release_memory
+
+ (self.model_wrapped,) = release_memory(self.model_wrapped)
+ self.model_wrapped = self.model
+
+ # Check for DeepSpeed *after* the initial pass and modify the config
+ if self.is_deepspeed_enabled:
+ # Temporarily unset `self.args.train_batch_size`
+ original_bs = self.args.per_device_train_batch_size
+ self.args.per_device_train_batch_size = self._train_batch_size // max(1, self.args.n_gpu)
+ self.propagate_args_to_deepspeed(True)
+ self.args.per_device_train_batch_size = original_bs
+ self.state.train_batch_size = self._train_batch_size
+ logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")
+ # Data loader and number of training steps
+ train_dataloader = self.get_train_dataloader()
+ if self.is_fsdp_xla_v2_enabled:
+ train_dataloader = tpu_spmd_dataloader(train_dataloader)
+
+ # Setting up training control variables:
+ # number of training epochs: num_train_epochs
+ # number of training steps per epoch: num_update_steps_per_epoch
+ # total number of training steps to execute: max_steps
+ total_train_batch_size = self.get_total_train_batch_size(args)
+
+ (
+ num_train_epochs,
+ num_update_steps_per_epoch,
+ num_examples,
+ num_train_samples,
+ epoch_based,
+ len_dataloader,
+ max_steps,
+ ) = self.set_initial_training_values(args, train_dataloader, total_train_batch_size)
+
+ num_train_tokens = None
+ if self.args.include_tokens_per_second:
+ num_train_tokens = self.num_tokens(train_dataloader, None if epoch_based else max_steps)
+ # If going by epochs, multiply tokens linearly
+ if len_dataloader is not None and epoch_based:
+ num_train_tokens *= args.num_train_epochs
+ # Otherwise since its steps, we just multiply by grad accum
+ else:
+ num_train_tokens *= args.gradient_accumulation_steps
+
+ if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
+ if self.args.n_gpu > 1:
+ # nn.DataParallel(model) replicates the model, creating new variables and module
+ # references registered here no longer work on other gpus, breaking the module
+ raise ValueError(
+ "Currently --debug underflow_overflow is not supported under DP. Please use DDP"
+ " (torchrun or torch.distributed.launch (deprecated))."
+ )
+ else:
+ DebugUnderflowOverflow(self.model)
+
+ delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
+
+ # Can't delay optimizer creation when using FSDP2: https://github.com/huggingface/accelerate/blob/3f636d626063ffcf9a337c7d3624d61b7d187d59/src/accelerate/accelerator.py#L1404
+ is_fsdp2 = self.is_fsdp_enabled and (getattr(self.accelerator.state.fsdp_plugin, "fsdp_version", 1) == 2)
+ if is_fsdp2:
+ delay_optimizer_creation = False
+
+ # We need to reset the scheduler, as its parameters may be different on subsequent calls
+ if self._created_lr_scheduler:
+ self.lr_scheduler = None
+ self._created_lr_scheduler = False
+
+ if self.is_deepspeed_enabled:
+ self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps)
+
+ if not delay_optimizer_creation:
+ self.create_optimizer_and_scheduler(num_training_steps=max_steps)
+
+ self.state = TrainerState(
+ stateful_callbacks=[
+ cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
+ ]
+ )
+ self.state.is_hyper_param_search = trial is not None
+ self.state.train_batch_size = self._train_batch_size
+
+ # Compute absolute values for logging, eval, and save if given as ratio
+ self.state.compute_steps(args, max_steps)
+
+ # Activate gradient checkpointing if needed
+ if args.gradient_checkpointing:
+ self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs)
+
+ model = self._wrap_model(self.model_wrapped)
+
+ # as the model is wrapped, don't use `accelerator.prepare`
+ # this is for unhandled cases such as
+ # FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
+ use_accelerator_prepare = model is self.model
+
+ if use_accelerator_prepare and self.is_fsdp_enabled:
+ # In case of auto_find_batch_size=True
+ # Remove FSDP wrapping from sub-models.
+ self.model = unwrap_model(self.model, recursive=True)
+
+ if delay_optimizer_creation:
+ if use_accelerator_prepare:
+ # configure fsdp plugin for qlora if any
+ self._fsdp_qlora_plugin_updates()
+ if self.accelerator.mixed_precision != "fp8":
+ self.model = self.accelerator.prepare(self.model)
+ self.create_optimizer_and_scheduler(num_training_steps=max_steps)
+
+ # prepare using `accelerator` prepare
+ if use_accelerator_prepare:
+ self.model.train()
+ if hasattr(self.lr_scheduler, "step"):
+ if self.use_apex:
+ model = self.accelerator.prepare(self.model)
+ else:
+ # We should avoid accelerate preparing the model in TP case since we dont need it as it is handled by transformers from_pretrained and also it goes into DDP based preparation.
+ if self.is_tp_enabled:
+ self.optimizer = self.accelerator.prepare(self.optimizer)
+ else:
+ model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
+ else:
+ # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config.
+ model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
+ self.model, self.optimizer, self.lr_scheduler
+ )
+ else:
+ self.optimizer = self.accelerator.prepare(self.optimizer)
+
+ if self.is_fsdp_enabled:
+ self.model = self.model_wrapped = model
+
+ # for the rest of this function `model` is the outside model, whether it was wrapped or not
+ if model is not self.model:
+ self.model_wrapped = model
+
+ # backward compatibility
+ if self.is_deepspeed_enabled:
+ self.deepspeed = self.model_wrapped
+
+ # ckpt loading
+ if resume_from_checkpoint is not None:
+ if self.is_deepspeed_enabled:
+ deepspeed_load_checkpoint(
+ self.model_wrapped, resume_from_checkpoint, load_module_strict=not _is_peft_model(self.model)
+ )
+ elif is_sagemaker_mp_enabled() or self.is_fsdp_enabled:
+ self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped)
+
+ # Check if saved optimizer or scheduler states exist
+ self._load_optimizer_and_scheduler(resume_from_checkpoint)
+ self._load_scaler(resume_from_checkpoint)
+
+ # important: at this point:
+ # self.model is the Transformers Model
+ # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model),
+ # FSDP(Transformers Model), Dynamo Optimized Module(Transformers Model) etc.
+
+ # Train!
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {num_examples:,}")
+ logger.info(f" Num Epochs = {num_train_epochs:,}")
+ logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")
+ if self.args.per_device_train_batch_size != self._train_batch_size:
+ logger.info(f" Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {max_steps:,}")
+ logger.info(f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}")
+
+ self.state.epoch = 0
+ start_time = time.time()
+ epochs_trained = 0
+ steps_trained_in_current_epoch = 0
+
+ # Check if continuing training from a checkpoint
+ if resume_from_checkpoint is not None and os.path.isfile(
+ os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
+ ):
+ self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
+ self.compare_trainer_and_checkpoint_args(self.args, self.state)
+ self._load_callback_state()
+ epochs_trained = int(self.state.global_step // num_update_steps_per_epoch)
+ if not args.ignore_data_skip:
+ steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
+ steps_trained_in_current_epoch *= args.gradient_accumulation_steps
+ else:
+ steps_trained_in_current_epoch = 0
+
+ logger.info(" Continuing training from checkpoint, will skip to saved global_step")
+ logger.info(f" Continuing training from epoch {epochs_trained}")
+ logger.info(f" Continuing training from global step {self.state.global_step}")
+ if not args.ignore_data_skip:
+ logger.info(
+ f" Will skip the first {epochs_trained} epochs then the first"
+ f" {steps_trained_in_current_epoch} batches in the first epoch."
+ )
+
+ # Update the references
+ for attr in ("model", "optimizer", "lr_scheduler"):
+ setattr(self.callback_handler, attr, getattr(self, attr))
+ self.callback_handler.train_dataloader = train_dataloader
+
+ self.state.init_training_references(self, max_steps, num_train_epochs, trial)
+
+ # tr_loss is a tensor to avoid synchronization of TPUs through .item()
+ tr_loss = torch.tensor(0.0, device=args.device)
+ # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
+ self._total_loss_scalar = 0.0
+ self._globalstep_last_logged = self.state.global_step
+ model.zero_grad()
+ grad_norm: Optional[float] = None
+ learning_rate = None
+ self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
+
+ if args.eval_on_start:
+ self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True)
+
+ for epoch in range(epochs_trained, num_train_epochs):
+ epoch_dataloader = train_dataloader
+ if hasattr(epoch_dataloader, "set_epoch"):
+ epoch_dataloader.set_epoch(epoch)
+
+ # Reset the past mems state at the beginning of each epoch if necessary.
+ if args.past_index >= 0:
+ self._past = None
+
+ steps_in_epoch = (
+ len(epoch_dataloader)
+ if len_dataloader is not None
+ else args.max_steps * args.gradient_accumulation_steps
+ )
+ self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)
+
+ step = -1
+ rng_to_sync = False
+
+ # Handle resumption from checkpoint
+ if epoch == epochs_trained and resume_from_checkpoint is not None:
+ if steps_trained_in_current_epoch > 0 and not args.ignore_data_skip:
+ epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch)
+ step = steps_trained_in_current_epoch - 1
+ rng_to_sync = True
+ elif steps_trained_in_current_epoch == 0:
+ self._load_rng_state(resume_from_checkpoint)
+
+ epoch_iterator = iter(epoch_dataloader)
+ # We chunkify the epoch iterator into gradient accumulation steps `n` batches
+ remainder = steps_in_epoch % args.gradient_accumulation_steps
+ if remainder == 0:
+ remainder = args.gradient_accumulation_steps
+ update_step = -1
+ total_updates = steps_in_epoch // args.gradient_accumulation_steps + int(
+ remainder < args.gradient_accumulation_steps
+ )
+ for _ in range(total_updates):
+ update_step += 1
+ num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
+ batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches, args.device)
+ # Store the number of batches for current gradient accumulation
+ # This is used to correctly scale the loss when the last accumulation step has fewer batches
+ self.current_gradient_accumulation_steps = len(batch_samples)
+ for i, inputs in enumerate(batch_samples):
+ step += 1
+ do_sync_step = (step + 1) % args.gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch
+ # Since we perform prefetching, we need to manually set sync_gradients
+ self.accelerator.gradient_state._set_sync_gradients(do_sync_step)
+
+ if self.args.include_num_input_tokens_seen not in ["no", False]:
+ main_input_name = getattr(self.model, "main_input_name", "input_ids")
+ if main_input_name not in inputs:
+ logger.warning(
+ "Tried to track the number of tokens seen, however the current model is "
+ "not configured properly to know what item is the input. To fix this, add "
+ "a `main_input_name` attribute to the model class you are using."
+ )
+ else:
+ if self.args.include_num_input_tokens_seen == "non_padding":
+ if "attention_mask" in inputs:
+ input_tokens = inputs["attention_mask"].sum()
+ elif (
+ self.processing_class is not None
+ and hasattr(self.processing_class, "pad_token_id")
+ and self.processing_class.pad_token_id is not None
+ ):
+ input_tokens = (
+ inputs[main_input_name] != self.processing_class.pad_token_id
+ ).sum()
+ else:
+ logger.warning(
+ "Could not determine method to count non-padding tokens, falling back to counting all tokens."
+ )
+ input_tokens = inputs[main_input_name].numel()
+ else:
+ input_tokens = inputs[main_input_name].numel()
+
+ input_tokens = torch.tensor(input_tokens, device=self.args.device, dtype=torch.int64)
+ self.state.num_input_tokens_seen += self.accelerator.gather(input_tokens).sum().item()
+
+ if rng_to_sync:
+ self._load_rng_state(resume_from_checkpoint)
+ rng_to_sync = False
+
+ if step % args.gradient_accumulation_steps == 0:
+ self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
+
+ # We explicitly want to avoid relying on `accelerator.accumulate` for generation training
+ context = (
+ functools.partial(self.accelerator.no_sync, model=model)
+ if i != len(batch_samples) - 1
+ and self.accelerator.distributed_type != DistributedType.DEEPSPEED
+ else contextlib.nullcontext
+ )
+ with context():
+ tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
+
+ if (
+ args.logging_nan_inf_filter
+ and not is_torch_xla_available()
+ and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
+ ):
+ # if loss is nan or inf simply add the average of previous logged losses
+ tr_loss = tr_loss + tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
+ else:
+ if tr_loss.device != tr_loss_step.device:
+ raise ValueError(
+ f"Calculated loss must be on the original device: {tr_loss.device} but device in use is {tr_loss_step.device}"
+ )
+ tr_loss = tr_loss + tr_loss_step
+
+ self.current_flos += float(self.floating_point_ops(inputs))
+
+ if do_sync_step:
+ # Since we perform prefetching, we need to manually set sync_gradients to True
+ self.accelerator.gradient_state._set_sync_gradients(True)
+
+ # Gradient clipping
+ if args.max_grad_norm is not None and args.max_grad_norm > 0:
+ if is_sagemaker_mp_enabled() and args.fp16:
+ _grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm)
+ elif self.use_apex:
+ from apex import amp
+
+ # Revert to normal clipping otherwise, handling Apex or full precision
+ _grad_norm = nn.utils.clip_grad_norm_(
+ amp.master_params(self.optimizer),
+ args.max_grad_norm,
+ )
+ else:
+ grad_norm_context = contextlib.nullcontext
+ if self.is_tp_enabled:
+ from torch.distributed._tensor.experimental import implicit_replication
+
+ grad_norm_context = implicit_replication
+ with grad_norm_context():
+ _grad_norm = self.accelerator.clip_grad_norm_(
+ model.parameters(),
+ args.max_grad_norm,
+ )
+
+ if (
+ is_accelerate_available()
+ and self.accelerator.distributed_type == DistributedType.DEEPSPEED
+ ):
+ grad_norm = model.get_global_grad_norm()
+ # In some cases the grad norm may not return a float
+ if hasattr(grad_norm, "item"):
+ grad_norm = grad_norm.item()
+ else:
+ grad_norm = _grad_norm
+
+ self.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control)
+
+ context = contextlib.nullcontext
+ if self.is_tp_enabled:
+ from torch.distributed._tensor.experimental import implicit_replication
+
+ context = implicit_replication
+
+ with context():
+ self.optimizer.step()
+
+ self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control)
+
+ # get leaning rate before update
+ learning_rate = self._get_learning_rate()
+
+ if not self.accelerator.optimizer_step_was_skipped:
+ # Delay optimizer scheduling until metrics are generated
+ if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
+ self.lr_scheduler.step()
+
+ model.zero_grad()
+ self.state.global_step += 1
+ self.state.epoch = epoch + (step + 1) / steps_in_epoch
+ self.control = self.callback_handler.on_step_end(args, self.state, self.control)
+ self._maybe_log_save_evaluate(
+ tr_loss,
+ grad_norm,
+ model,
+ trial,
+ epoch,
+ ignore_keys_for_eval,
+ start_time,
+ learning_rate=learning_rate,
+ )
+ else:
+ self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
+
+ # PyTorch/XLA relies on the data loader to insert the mark_step for
+ # each step. Since we are breaking the loop early, we need to manually
+ # insert the mark_step here.
+ if self.control.should_epoch_stop or self.control.should_training_stop:
+ if is_torch_xla_available():
+ xm.mark_step()
+ break
+ # We also need to break out of the nested loop
+ if self.control.should_epoch_stop or self.control.should_training_stop:
+ if is_torch_xla_available():
+ xm.mark_step()
+ break
+ if step < 0:
+ logger.warning(
+ "There seems not to be a single sample in your epoch_iterator, stopping training at step"
+ f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
+ f" num_steps ({max_steps}) higher than the number of available samples."
+ )
+ self.control.should_training_stop = True
+
+ self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
+ self._maybe_log_save_evaluate(
+ tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, learning_rate=learning_rate
+ )
+
+ if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
+ if is_torch_xla_available():
+ # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
+ xm.master_print(met.metrics_report())
+ else:
+ logger.warning(
+ "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
+ "configured. Check your training configuration if this is unexpected."
+ )
+ if self.control.should_training_stop:
+ break
+
+ if args.past_index and hasattr(self, "_past"):
+ # Clean the state at the end of training
+ delattr(self, "_past")
+
+ logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
+ if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
+ self._load_best_model()
+
+ # add remaining tr_loss
+ self._total_loss_scalar += tr_loss.item()
+ effective_global_step = max(self.state.global_step, 0.001) # Avoid ZeroDivisionError
+ train_loss = self._total_loss_scalar / effective_global_step
+
+ metrics = speed_metrics(
+ "train",
+ start_time,
+ num_samples=num_train_samples,
+ num_steps=self.state.max_steps,
+ num_tokens=num_train_tokens,
+ )
+ self.store_flos()
+ metrics["total_flos"] = self.state.total_flos
+ metrics["train_loss"] = train_loss
+
+ self.is_in_train = False
+
+ self._memory_tracker.stop_and_update_metrics(metrics)
+
+ self.log(metrics)
+
+ run_dir = self._get_output_dir(trial)
+ checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir)
+
+ # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save.
+ if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1:
+ for checkpoint in checkpoints_sorted:
+ if not os.path.samefile(checkpoint, self.state.best_model_checkpoint):
+ logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
+ shutil.rmtree(checkpoint, ignore_errors=True)
+
+ self.control = self.callback_handler.on_train_end(args, self.state, self.control)
+
+ # Wait for the checkpoint to be uploaded.
+ self._finish_current_push()
+
+ # After training we make sure to retrieve back the original forward pass method
+ # for the embedding layer by removing the forward post hook.
+ if self.neftune_noise_alpha is not None:
+ self._deactivate_neftune(self.model)
+
+ return TrainOutput(self.state.global_step, train_loss, metrics)
+
+ def _get_output_dir(self, trial):
+ if self.hp_search_backend is not None and trial is not None:
+ if self.hp_search_backend == HPSearchBackend.OPTUNA:
+ run_id = trial.number
+ elif self.hp_search_backend == HPSearchBackend.RAY:
+ import ray.train
+
+ run_id = ray.train.get_context().get_trial_id()
+ elif self.hp_search_backend == HPSearchBackend.SIGOPT:
+ run_id = trial.id
+ elif self.hp_search_backend == HPSearchBackend.WANDB:
+ import wandb
+
+ run_id = wandb.run.id
+ run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
+ run_dir = os.path.join(self.args.output_dir, run_name)
+ else:
+ run_dir = self.args.output_dir
+ return run_dir
+
+ def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
+ if model is None:
+ model = self.model
+
+ config_file = os.path.join(resume_from_checkpoint, CONFIG_NAME)
+ adapter_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_WEIGHTS_NAME)
+ adapter_safe_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME)
+ weights_file = os.path.join(resume_from_checkpoint, WEIGHTS_NAME)
+ weights_index_file = os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
+ safe_weights_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_NAME)
+ safe_weights_index_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_INDEX_NAME)
+ is_fsdp_ckpt = os.path.isdir(resume_from_checkpoint) and (
+ # this checks the FSDP state dict when `SHARDED_STATE_DICT` is used
+ any(
+ FSDP_MODEL_NAME in folder_name
+ for folder_name in os.listdir(resume_from_checkpoint)
+ if os.path.isdir(os.path.join(resume_from_checkpoint, folder_name))
+ )
+ # this checks the FSDP state dict when `FULL_STATE_DICT` is used
+ or os.path.isfile(os.path.join(resume_from_checkpoint, f"{FSDP_MODEL_NAME}.bin"))
+ )
+ # if multiple adapters exist, they get saved in sub directories
+ adapter_subdirs = (
+ [
+ folder_name
+ for folder_name in os.listdir(resume_from_checkpoint)
+ if os.path.isdir(os.path.join(resume_from_checkpoint, folder_name))
+ and (
+ os.path.isfile(os.path.join(resume_from_checkpoint, folder_name, ADAPTER_WEIGHTS_NAME))
+ or os.path.isfile(os.path.join(resume_from_checkpoint, folder_name, ADAPTER_SAFE_WEIGHTS_NAME))
+ )
+ ]
+ if os.path.isdir(resume_from_checkpoint)
+ else []
+ )
+
+ if is_fsdp_ckpt and not self.is_fsdp_enabled:
+ raise ValueError(f"Checkpoint found at {resume_from_checkpoint} is only supported when using PyTorch FSDP")
+
+ if not (
+ any(
+ os.path.isfile(f)
+ for f in [
+ weights_file,
+ safe_weights_file,
+ weights_index_file,
+ safe_weights_index_file,
+ adapter_weights_file,
+ adapter_safe_weights_file,
+ ]
+ )
+ or is_fsdp_ckpt
+ or adapter_subdirs
+ ):
+ raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
+
+ logger.info(f"Loading model from {resume_from_checkpoint}.")
+
+ if os.path.isfile(config_file):
+ config = PretrainedConfig.from_json_file(config_file)
+ checkpoint_version = config.transformers_version
+ if checkpoint_version is not None and checkpoint_version != __version__:
+ logger.warning(
+ f"You are resuming training from a checkpoint trained with {checkpoint_version} of "
+ f"Transformers but your current version is {__version__}. This is not recommended and could "
+ "yield to errors or unwanted behaviors."
+ )
+
+ if os.path.isfile(weights_file) or os.path.isfile(safe_weights_file) or is_fsdp_ckpt:
+ # If the model is on the GPU, it still works!
+ if is_sagemaker_mp_enabled():
+ if os.path.isfile(os.path.join(resume_from_checkpoint, "user_content.pt")):
+ # If the 'user_content.pt' file exists, load with the new smp api.
+ # Checkpoint must have been saved with the new smp api.
+ smp.resume_from_checkpoint(
+ path=resume_from_checkpoint, tag=WEIGHTS_NAME, partial=False, load_optimizer=False
+ )
+ else:
+ # If the 'user_content.pt' file does NOT exist, load with the old smp api.
+ # Checkpoint must have been saved with the old smp api.
+ if hasattr(self.args, "fp16") and self.args.fp16 is True:
+ logger.warning(
+ "Enabling FP16 and loading from smp < 1.10 checkpoint together is not supported."
+ )
+ check_torch_load_is_safe()
+ state_dict = torch.load(weights_file, map_location="cpu", weights_only=True)
+ # Required for smp to not auto-translate state_dict from hf to smp (is already smp).
+ state_dict["_smp_is_partial"] = False
+ load_result = model.load_state_dict(state_dict, strict=True)
+ # release memory
+ del state_dict
+ elif self.is_fsdp_enabled:
+ load_fsdp_model(
+ self.accelerator.state.fsdp_plugin,
+ self.accelerator,
+ model,
+ resume_from_checkpoint,
+ **_get_fsdp_ckpt_kwargs(),
+ )
+ else:
+ # We load the model state dict on the CPU to avoid an OOM error.
+ if self.args.save_safetensors and os.path.isfile(safe_weights_file):
+ state_dict = safetensors.torch.load_file(safe_weights_file, device="cpu")
+ else:
+ check_torch_load_is_safe()
+ state_dict = torch.load(weights_file, map_location="cpu", weights_only=True)
+
+ # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
+ # which takes *args instead of **kwargs
+ load_result = model.load_state_dict(state_dict, False)
+ # release memory
+ del state_dict
+ self._issue_warnings_after_load(load_result)
+
+ # Load adapters following PR # 24096
+ elif _is_peft_model(model):
+ # If train a model using PEFT & LoRA, assume that adapter have been saved properly.
+ # TODO: in the future support only specific min PEFT versions
+ if (hasattr(model, "active_adapter") or hasattr(model, "active_adapters")) and hasattr(
+ model, "load_adapter"
+ ):
+ if os.path.exists(resume_from_checkpoint):
+ # For BC for older PEFT versions
+ if hasattr(model, "active_adapters"):
+ active_adapters = model.active_adapters
+ if len(active_adapters) > 1:
+ logger.warning("Multiple active adapters detected will only consider the first adapter")
+ active_adapter = active_adapters[0]
+ else:
+ active_adapter = model.active_adapter
+
+ if adapter_subdirs:
+ for subdir_name in adapter_subdirs:
+ peft_id = os.path.join(resume_from_checkpoint, subdir_name)
+ model.load_adapter(peft_id, subdir_name, is_trainable=(subdir_name == active_adapter))
+ model.set_adapter(active_adapter)
+ else:
+ model.load_adapter(resume_from_checkpoint, active_adapter, is_trainable=True)
+ else:
+ logger.warning(
+ "The intermediate checkpoints of PEFT may not be saved correctly, "
+ f"consider using a custom callback to save {ADAPTER_WEIGHTS_NAME} in corresponding saving folders. "
+ "Check some examples here: https://github.com/huggingface/peft/issues/96"
+ )
+ else:
+ logger.warning("Could not load adapter model, make sure to have `peft>=0.3.0` installed")
+ else:
+ # We load the sharded checkpoint
+ load_result = load_sharded_checkpoint(
+ model, resume_from_checkpoint, strict=is_sagemaker_mp_enabled(), prefer_safe=self.args.save_safetensors
+ )
+ if not is_sagemaker_mp_enabled():
+ self._issue_warnings_after_load(load_result)
+
+ def _load_best_model(self):
+ logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
+ best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
+ best_safe_model_path = os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_NAME)
+ best_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_WEIGHTS_NAME)
+ best_safe_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME)
+
+ model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
+ if self.is_deepspeed_enabled:
+ deepspeed_load_checkpoint(
+ self.model_wrapped,
+ self.state.best_model_checkpoint,
+ load_module_strict=not _is_peft_model(self.model),
+ )
+ elif self.is_fsdp_enabled:
+ load_result = load_fsdp_model(
+ self.accelerator.state.fsdp_plugin,
+ self.accelerator,
+ model,
+ self.state.best_model_checkpoint,
+ **_get_fsdp_ckpt_kwargs(),
+ )
+ elif (
+ os.path.exists(best_model_path)
+ or os.path.exists(best_safe_model_path)
+ or os.path.exists(best_adapter_model_path)
+ or os.path.exists(best_safe_adapter_model_path)
+ ):
+ has_been_loaded = True
+ if is_sagemaker_mp_enabled():
+ if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")):
+ # If the 'user_content.pt' file exists, load with the new smp api.
+ # Checkpoint must have been saved with the new smp api.
+ smp.resume_from_checkpoint(
+ path=self.state.best_model_checkpoint,
+ tag=WEIGHTS_NAME,
+ partial=False,
+ load_optimizer=False,
+ )
+ else:
+ # If the 'user_content.pt' file does NOT exist, load with the old smp api.
+ # Checkpoint must have been saved with the old smp api.
+ if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
+ state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
+ else:
+ check_torch_load_is_safe()
+ state_dict = torch.load(best_model_path, map_location="cpu", weights_only=True)
+
+ state_dict["_smp_is_partial"] = False
+ load_result = model.load_state_dict(state_dict, strict=True)
+ else:
+ if _is_peft_model(model):
+ # If train a model using PEFT & LoRA, assume that adapter have been saved properly.
+ # TODO: in the future support only specific min PEFT versions
+ if (hasattr(model, "active_adapter") or hasattr(model, "active_adapters")) and hasattr(
+ model, "load_adapter"
+ ):
+ # For BC for older PEFT versions
+ if hasattr(model, "active_adapters"):
+ active_adapter = model.active_adapters[0]
+ if len(model.active_adapters) > 1:
+ logger.warning("Detected multiple active adapters, will only consider the first one")
+ else:
+ active_adapter = model.active_adapter
+
+ if os.path.exists(best_adapter_model_path) or os.path.exists(best_safe_adapter_model_path):
+ try:
+ model.load_adapter(self.state.best_model_checkpoint, active_adapter)
+ except RuntimeError as exc:
+ if model.peft_config[active_adapter].is_prompt_learning:
+ # for context: https://github.com/huggingface/peft/issues/2256
+ msg = (
+ "When using prompt learning PEFT methods such as "
+ f"{model.peft_config[active_adapter].peft_type.value}, setting "
+ "load_best_model_at_end=True can lead to errors, it is recommended "
+ "to set this to False and to load the model manually from the checkpoint "
+ "directory using PeftModel.from_pretrained(base_model, ) after training "
+ "has finished."
+ )
+ raise RuntimeError(msg) from exc
+ else:
+ raise
+ # Load_adapter has no return value present, modify it when appropriate.
+ from torch.nn.modules.module import _IncompatibleKeys
+
+ load_result = _IncompatibleKeys([], [])
+ else:
+ logger.warning(
+ "The intermediate checkpoints of PEFT may not be saved correctly, "
+ f"consider using a custom callback to save {ADAPTER_WEIGHTS_NAME} in corresponding saving folders. "
+ "Check some examples here: https://github.com/huggingface/peft/issues/96"
+ )
+ has_been_loaded = False
+ else:
+ logger.warning("Could not load adapter model, make sure to have `peft>=0.3.0` installed")
+ has_been_loaded = False
+ else:
+ # We load the model state dict on the CPU to avoid an OOM error.
+ if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
+ state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
+ else:
+ check_torch_load_is_safe()
+ state_dict = torch.load(best_model_path, map_location="cpu", weights_only=True)
+
+ # If the model is on the GPU, it still works!
+ # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
+ # which takes *args instead of **kwargs
+ load_result = model.load_state_dict(state_dict, False)
+ if not is_sagemaker_mp_enabled() and has_been_loaded:
+ self._issue_warnings_after_load(load_result)
+ elif os.path.exists(os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_INDEX_NAME)) or os.path.exists(
+ os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)
+ ):
+ load_result = load_sharded_checkpoint(
+ model, self.state.best_model_checkpoint, strict=is_sagemaker_mp_enabled()
+ )
+ if not is_sagemaker_mp_enabled():
+ self._issue_warnings_after_load(load_result)
+ else:
+ logger.warning(
+ f"Could not locate the best model at {best_model_path}, if you are running a distributed training "
+ "on multiple nodes, you should activate `--save_on_each_node`."
+ )
+
+ def _issue_warnings_after_load(self, load_result):
+ if len(load_result.missing_keys) != 0:
+ if self.model._keys_to_ignore_on_save is not None and set(load_result.missing_keys) == set(
+ self.model._keys_to_ignore_on_save
+ ):
+ self.model.tie_weights()
+ else:
+ logger.warning(f"There were missing keys in the checkpoint model loaded: {load_result.missing_keys}.")
+ if len(load_result.unexpected_keys) != 0:
+ logger.warning(
+ f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}."
+ )
+
+ def _evaluate(self, trial, ignore_keys_for_eval, skip_scheduler=False):
+ metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
+ self._report_to_hp_search(trial, self.state.global_step, metrics)
+
+ # Run delayed LR scheduler now that metrics are populated
+ if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau) and not skip_scheduler:
+ metric_to_check = self.args.metric_for_best_model
+ if not metric_to_check.startswith("eval_"):
+ metric_to_check = f"eval_{metric_to_check}"
+ try:
+ self.lr_scheduler.step(metrics[metric_to_check])
+ except KeyError as exc:
+ raise KeyError(
+ f"The `metric_for_best_model` training argument is set to '{metric_to_check}', "
+ f"which is not found in the evaluation metrics. "
+ f"The available evaluation metrics are: {list(metrics.keys())}. "
+ f"Please ensure that the `compute_metrics` function returns a dictionary that includes '{metric_to_check}' or "
+ f"consider changing the `metric_for_best_model` via the TrainingArguments."
+ ) from exc
+ return metrics
+
+ def _maybe_log_save_evaluate(
+ self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, learning_rate=None
+ ):
+ if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
+ if is_torch_xla_available():
+ xm.mark_step()
+
+ logs: dict[str, float] = {}
+
+ # all_gather + mean() to get average loss over all processes
+ tr_loss_scalar = self._nested_gather(tr_loss).mean().item()
+
+ # reset tr_loss to zero
+ tr_loss -= tr_loss
+
+ logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
+ if grad_norm is not None:
+ logs["grad_norm"] = grad_norm.item() if isinstance(grad_norm, torch.Tensor) else grad_norm
+ if learning_rate is not None:
+ logs["learning_rate"] = learning_rate
+ else:
+ logs["learning_rate"] = self._get_learning_rate()
+
+ self._total_loss_scalar += tr_loss_scalar
+ self._globalstep_last_logged = self.state.global_step
+ self.store_flos()
+
+ self.log(logs, start_time)
+
+ metrics = None
+ if self.control.should_evaluate:
+ metrics = self._evaluate(trial, ignore_keys_for_eval)
+ is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial)
+
+ if self.args.save_strategy == SaveStrategy.BEST:
+ self.control.should_save = is_new_best_metric
+
+ if self.control.should_save:
+ self._save_checkpoint(model, trial)
+ self.control = self.callback_handler.on_save(self.args, self.state, self.control)
+
+ def _load_rng_state(self, checkpoint):
+ # Load RNG states from `checkpoint`
+ if checkpoint is None:
+ return
+
+ if self.args.world_size > 1:
+ process_index = self.args.process_index
+ rng_file = os.path.join(checkpoint, f"rng_state_{process_index}.pth")
+ if not os.path.isfile(rng_file):
+ logger.info(
+ f"Didn't find an RNG file for process {process_index}, if you are resuming a training that "
+ "wasn't launched in a distributed fashion, reproducibility is not guaranteed."
+ )
+ return
+ else:
+ rng_file = os.path.join(checkpoint, "rng_state.pth")
+ if not os.path.isfile(rng_file):
+ logger.info(
+ "Didn't find an RNG file, if you are resuming a training that was launched in a distributed "
+ "fashion, reproducibility is not guaranteed."
+ )
+ return
+
+ with safe_globals():
+ checkpoint_rng_state = torch.load(rng_file)
+ random.setstate(checkpoint_rng_state["python"])
+ np.random.set_state(checkpoint_rng_state["numpy"])
+ torch.random.set_rng_state(checkpoint_rng_state["cpu"])
+ if is_torch_xla_available():
+ xm.set_rng_state(checkpoint_rng_state["xla"])
+
+ is_distributed = self.args.parallel_mode == ParallelMode.DISTRIBUTED
+ if torch.cuda.is_available():
+ set_rng_state_for_device("CUDA", torch.cuda, checkpoint_rng_state, is_distributed)
+ if is_torch_npu_available():
+ set_rng_state_for_device("NPU", torch.npu, checkpoint_rng_state, is_distributed)
+ if is_torch_hpu_available():
+ set_rng_state_for_device("HPU", torch.hpu, checkpoint_rng_state, is_distributed)
+ if is_torch_mlu_available():
+ set_rng_state_for_device("MLU", torch.mlu, checkpoint_rng_state, is_distributed)
+ if is_torch_musa_available():
+ set_rng_state_for_device("MUSA", torch.musa, checkpoint_rng_state, is_distributed)
+
+ def _determine_best_metric(self, metrics, trial):
+ """
+ Determine if the model should be saved based on the evaluation metrics.
+
+ Returns:
+ bool: True if a new best metric was found, else False
+ """
+ is_new_best_metric = False
+
+ if self.args.metric_for_best_model is not None:
+ metric_to_check = self.args.metric_for_best_model
+
+ if not metric_to_check.startswith("eval_"):
+ metric_to_check = f"eval_{metric_to_check}"
+
+ try:
+ metric_value = metrics[metric_to_check]
+ except KeyError as exc:
+ raise KeyError(
+ f"The `metric_for_best_model` training argument is set to '{metric_to_check}', which is not found in the evaluation metrics. "
+ f"The available evaluation metrics are: {list(metrics.keys())}. Consider changing the `metric_for_best_model` via the TrainingArguments."
+ ) from exc
+
+ operator = np.greater if self.args.greater_is_better else np.less
+
+ if self.state.best_metric is None:
+ self.state.best_metric = float("-inf") if self.args.greater_is_better else float("inf")
+
+ if operator(metric_value, self.state.best_metric):
+ self.state.best_metric = metric_value
+
+ if self.args.save_strategy in [SaveStrategy.STEPS, SaveStrategy.EPOCH]:
+ self.state.best_global_step = self.state.global_step
+
+ is_new_best_metric = True
+
+ return is_new_best_metric
+
+ def _save_checkpoint(self, model, trial):
+ # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
+ # want to save except FullyShardedDDP.
+ # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
+
+ # Save model checkpoint
+ checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
+
+ if self.hp_search_backend is None and trial is None:
+ self.store_flos()
+
+ run_dir = self._get_output_dir(trial=trial)
+ output_dir = os.path.join(run_dir, checkpoint_folder)
+ self.save_model(output_dir, _internal_call=True)
+
+ if self.args.save_strategy in [SaveStrategy.STEPS, SaveStrategy.EPOCH] and self.state.best_global_step:
+ # Wait for everyone to get here so we are sure the model has been saved by process 0
+ # before we check if the best_checkpoint_dir exists
+ if is_torch_xla_available():
+ xm.rendezvous("load_best_model_at_end")
+ elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+ dist.barrier()
+ elif is_sagemaker_mp_enabled():
+ smp.barrier()
+
+ best_checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.best_global_step}"
+ best_checkpoint_dir = os.path.join(run_dir, best_checkpoint_folder)
+
+ if os.path.exists(best_checkpoint_dir):
+ self.state.best_model_checkpoint = best_checkpoint_dir
+
+ if not self.args.save_only_model:
+ # Save optimizer and scheduler
+ self._save_optimizer_and_scheduler(output_dir)
+ self._save_scaler(output_dir)
+ # Save RNG state
+ self._save_rng_state(output_dir)
+
+ # Save the Trainer state
+ if self.args.should_save:
+ # Update `ExportableState` callbacks and `TrainerControl` state to where we are currently
+ for cb in [
+ cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
+ ]:
+ cb_name = cb.__class__.__name__
+ cb_state = cb.state()
+ if isinstance(self.state.stateful_callbacks[cb_name], list):
+ self.state.stateful_callbacks[cb_name].append(cb_state)
+ else:
+ self.state.stateful_callbacks[cb_name] = cb_state
+ self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
+
+ if self.args.push_to_hub:
+ self._push_from_checkpoint(output_dir)
+
+ # Maybe delete some older checkpoints.
+ if self.args.should_save:
+ # we use mtime as default, filesystems without mtime support will be detected in `_sorted_checkpoints`
+ self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
+
+ def _save_rng_state(self, output_dir):
+ # Save RNG state in non-distributed training
+ rng_states = {
+ "python": random.getstate(),
+ "numpy": np.random.get_state(),
+ "cpu": torch.random.get_rng_state(),
+ }
+ if torch.cuda.is_available():
+ if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+ # In non distributed, we save the global CUDA RNG state (will take care of DataParallel)
+ rng_states["cuda"] = torch.cuda.random.get_rng_state_all()
+ else:
+ rng_states["cuda"] = torch.cuda.random.get_rng_state()
+
+ if is_torch_xla_available():
+ rng_states["xla"] = xm.get_rng_state()
+
+ if is_torch_npu_available():
+ if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+ rng_states["npu"] = torch.npu.random.get_rng_state_all()
+ else:
+ rng_states["npu"] = torch.npu.random.get_rng_state()
+
+ if is_torch_hpu_available():
+ if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+ rng_states["hpu"] = torch.hpu.random.get_rng_state_all()
+ else:
+ rng_states["hpu"] = torch.hpu.random.get_rng_state()
+
+ if is_torch_mlu_available():
+ if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+ rng_states["mlu"] = torch.mlu.random.get_rng_state_all()
+ else:
+ rng_states["mlu"] = torch.mlu.random.get_rng_state()
+
+ if is_torch_musa_available():
+ if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+ rng_states["musa"] = torch.musa.get_rng_state_all()
+ else:
+ rng_states["musa"] = torch.musa.get_rng_state()
+
+ # A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may
+ # not yet exist.
+ os.makedirs(output_dir, exist_ok=True)
+
+ if self.args.world_size <= 1:
+ torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
+ else:
+ torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth"))
+
+ def _save_optimizer_and_scheduler(self, output_dir):
+ if is_torch_xla_available():
+ xm.rendezvous("saving_optimizer_states")
+ if self.is_fsdp_xla_v1_enabled:
+ optm = {
+ "optimizer": self.optimizer.state_dict(),
+ "shard_metadata": self.model.get_shard_metadata(),
+ }
+ xm.save(
+ optm,
+ os.path.join(
+ output_dir, f"rank{self.args.process_index}-of-{self.args.world_size}-{OPTIMIZER_NAME}"
+ ),
+ master_only=False,
+ )
+ else:
+ xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
+ with warnings.catch_warnings(record=True) as caught_warnings:
+ xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
+ reissue_pt_warnings(caught_warnings)
+ elif is_sagemaker_mp_enabled():
+ opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False)
+ smp.barrier()
+ if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state:
+ smp.save(
+ opt_state_dict,
+ os.path.join(output_dir, OPTIMIZER_NAME),
+ partial=True,
+ v3=smp.state.cfg.shard_optimizer_state,
+ )
+ elif self.is_deepspeed_enabled:
+ # under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed
+ # config `stage3_gather_16bit_weights_on_model_save` is True
+ accept_exclude_frozen_parameters = "exclude_frozen_parameters" in set(
+ inspect.signature(self.model_wrapped.save_checkpoint).parameters.keys()
+ )
+ if accept_exclude_frozen_parameters and _is_peft_model(self.model):
+ self.model_wrapped.save_checkpoint(output_dir, exclude_frozen_parameters=True)
+ else:
+ self.model_wrapped.save_checkpoint(output_dir)
+ elif self.is_fsdp_enabled:
+ # save fsdp specific ckpt for resuming from ckpt
+ save_fsdp_model(
+ self.accelerator.state.fsdp_plugin, self.accelerator, self.model, output_dir, **_get_fsdp_ckpt_kwargs()
+ )
+ save_fsdp_optimizer(
+ self.accelerator.state.fsdp_plugin, self.accelerator, self.optimizer, self.model, output_dir
+ )
+ elif self.args.should_save:
+ # deepspeed.save_checkpoint above saves model/optim/sched
+ torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
+
+ # Save SCHEDULER & SCALER
+ is_deepspeed_custom_scheduler = self.is_deepspeed_enabled and not isinstance(
+ self.lr_scheduler, DeepSpeedSchedulerWrapper
+ )
+ if (
+ self.args.should_save
+ and (not self.is_deepspeed_enabled or is_deepspeed_custom_scheduler)
+ and not is_torch_xla_available()
+ ):
+ with warnings.catch_warnings(record=True) as caught_warnings:
+ torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
+ reissue_pt_warnings(caught_warnings)
+
+ def _load_optimizer_and_scheduler(self, checkpoint):
+ """If optimizer and scheduler states exist, load them."""
+ if checkpoint is None:
+ return
+
+ if self.is_deepspeed_enabled:
+ # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
+ if not isinstance(self.lr_scheduler, DeepSpeedSchedulerWrapper):
+ with warnings.catch_warnings(record=True) as caught_warnings:
+ check_torch_load_is_safe()
+ self.lr_scheduler.load_state_dict(
+ torch.load(os.path.join(checkpoint, SCHEDULER_NAME), weights_only=True)
+ )
+ reissue_pt_warnings(caught_warnings)
+ return
+
+ checkpoint_file_exists = (
+ glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + "_*")
+ if is_sagemaker_mp_enabled()
+ else (
+ os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME))
+ or os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME_BIN))
+ or (
+ os.path.isdir(checkpoint)
+ and any(
+ OPTIMIZER_NAME_BIN.split(".")[0] in folder_name
+ for folder_name in os.listdir(checkpoint)
+ if os.path.isdir(os.path.join(checkpoint, folder_name))
+ )
+ )
+ )
+ )
+ checkpoint_file_exists = (
+ glob.glob(os.path.join(checkpoint, f"rank*-of-{self.args.world_size}-{OPTIMIZER_NAME}"))
+ if self.is_fsdp_xla_v1_enabled
+ else checkpoint_file_exists
+ )
+ if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
+ # Load in optimizer and scheduler states
+ if is_torch_xla_available():
+ # On TPU we have to take some extra precautions to properly load the states on the right device.
+ if self.is_fsdp_xla_v1_enabled:
+ check_torch_load_is_safe()
+ optimizer_state = torch.load(
+ os.path.join(
+ checkpoint, f"rank{self.args.process_index}-of-{self.args.world_size}-{OPTIMIZER_NAME}"
+ ),
+ map_location="cpu",
+ weights_only=True,
+ )
+ # We only need `optimizer` when resuming from checkpoint
+ optimizer_state = optimizer_state["optimizer"]
+ else:
+ check_torch_load_is_safe()
+ optimizer_state = torch.load(
+ os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu", weights_only=True
+ )
+ with warnings.catch_warnings(record=True) as caught_warnings:
+ check_torch_load_is_safe()
+ lr_scheduler_state = torch.load(
+ os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu", weights_only=True
+ )
+ reissue_pt_warnings(caught_warnings)
+
+ xm.send_cpu_data_to_device(optimizer_state, self.args.device)
+ xm.send_cpu_data_to_device(lr_scheduler_state, self.args.device)
+
+ self.optimizer.load_state_dict(optimizer_state)
+ self.lr_scheduler.load_state_dict(lr_scheduler_state)
+ else:
+ if is_sagemaker_mp_enabled():
+ if os.path.isfile(os.path.join(checkpoint, "user_content.pt")):
+ # Optimizer checkpoint was saved with smp >= 1.10
+ def opt_load_hook(mod, opt):
+ opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))
+
+ else:
+ # Optimizer checkpoint was saved with smp < 1.10
+ def opt_load_hook(mod, opt):
+ if IS_SAGEMAKER_MP_POST_1_10:
+ opt.load_state_dict(
+ smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True, back_compat=True)
+ )
+ else:
+ opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))
+
+ self.model_wrapped.register_post_step_hook(opt_load_hook)
+ else:
+ # We use the CPU when training on one GPU to avoid OOM for GPU RAM when training big models.
+ # In distributed training however, we load directly on each GPU and risk the GPU OOM as it's more
+ # likely to get OOM on CPU (since we load num_gpu times the optimizer state
+ map_location = self.args.device if self.args.world_size > 1 else "cpu"
+ if self.is_fsdp_enabled:
+ load_fsdp_optimizer(
+ self.accelerator.state.fsdp_plugin,
+ self.accelerator,
+ self.optimizer,
+ self.model,
+ checkpoint,
+ **_get_fsdp_ckpt_kwargs(),
+ )
+ else:
+ check_torch_load_is_safe()
+ self.optimizer.load_state_dict(
+ torch.load(
+ os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location, weights_only=True
+ )
+ )
+ with warnings.catch_warnings(record=True) as caught_warnings:
+ check_torch_load_is_safe()
+ self.lr_scheduler.load_state_dict(
+ torch.load(os.path.join(checkpoint, SCHEDULER_NAME), weights_only=True)
+ )
+ reissue_pt_warnings(caught_warnings)
+
+ def _save_scaler(self, output_dir):
+ # See if there is a scaler attribute
+ try:
+ scaler = self.accelerator.scaler
+ except AttributeError:
+ return
+ if scaler is None:
+ return
+ if is_torch_xla_available():
+ xm.rendezvous("saving_scaler_state")
+ with warnings.catch_warnings(record=True) as caught_warnings:
+ xm.save(self.accelerator.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
+ reissue_pt_warnings(caught_warnings)
+
+ # Save SCALER
+ if self.args.should_save and not is_torch_xla_available():
+ with warnings.catch_warnings(record=True) as caught_warnings:
+ torch.save(self.accelerator.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
+ reissue_pt_warnings(caught_warnings)
+
+ def _load_scaler(self, checkpoint):
+ """If scaler state exists, load it."""
+ if checkpoint is None:
+ return
+
+ checkpoint_file_exists = os.path.isfile(os.path.join(checkpoint, SCALER_NAME))
+
+ if checkpoint_file_exists:
+ # On TPU we have to take some extra precautions to properly load the states on the right device.
+ # Load in scaler states
+ if is_torch_xla_available():
+ with warnings.catch_warnings(record=True) as caught_warnings:
+ check_torch_load_is_safe()
+ scaler_state = torch.load(
+ os.path.join(checkpoint, SCALER_NAME), map_location="cpu", weights_only=True
+ )
+ reissue_pt_warnings(caught_warnings)
+ xm.send_cpu_data_to_device(scaler_state, self.args.device)
+ self.accelerator.scaler.load_state_dict(scaler_state)
+ else:
+ with warnings.catch_warnings(record=True) as caught_warnings:
+ check_torch_load_is_safe()
+ self.accelerator.scaler.load_state_dict(
+ torch.load(os.path.join(checkpoint, SCALER_NAME), weights_only=True)
+ )
+ reissue_pt_warnings(caught_warnings)
+
+ def _load_callback_state(self):
+ """If callback states exist and were passed in, restore their states if enabled"""
+ if not self.args.restore_callback_states_from_checkpoint:
+ return
+ # Callback states are stored in stateful_callbacks
+ not_found = []
+ new_callbacks = []
+ original_callbacks = self.callback_handler.callbacks + [self.control]
+ for stored_callback, data in self.state.stateful_callbacks.items():
+ if not isinstance(data, list):
+ data = [data]
+ if any(callback.__class__.__name__ == stored_callback for callback in original_callbacks):
+ # We can load/restore from multiple callbacks of the same type.
+ duplicates = [
+ callback for callback in original_callbacks if callback.__class__.__name__ == stored_callback
+ ]
+ for callback, callback_data in zip(duplicates, data):
+ args = callback_data.get("args", {})
+ attributes = callback_data.get("attributes", {})
+ new_callback = type(callback)(**args)
+ for attribute, value in attributes.items():
+ setattr(new_callback, attribute, value)
+ if isinstance(callback, TrainerControl):
+ # Specifically for restoring the `control` state
+ self.control = new_callback
+ else:
+ new_callbacks.append(new_callback)
+ # We remove the existing callback and add it to the list of new callbacks
+ self.callback_handler.remove_callback(type(new_callback))
+ logger.info("Continuing training from checkpoint, restoring any callbacks that were passed in")
+ else:
+ not_found.append(stored_callback)
+ if len(not_found) > 0:
+ logger.warning(
+ f"Checkpoint included callbacks not included in current configuration. Ignoring. ({', '.join(not_found)})"
+ )
+ for callback in new_callbacks:
+ self.callback_handler.add_callback(callback)
+
+ def hyperparameter_search(
+ self,
+ hp_space: Optional[Callable[["optuna.Trial"], dict[str, float]]] = None,
+ compute_objective: Optional[Callable[[dict[str, float]], float]] = None,
+ n_trials: int = 20,
+ direction: Union[str, list[str]] = "minimize",
+ backend: Optional[Union["str", HPSearchBackend]] = None,
+ hp_name: Optional[Callable[["optuna.Trial"], str]] = None,
+ **kwargs,
+ ) -> Union[BestRun, list[BestRun]]:
+ """
+ Launch an hyperparameter search using `optuna` or `Ray Tune` or `SigOpt`. The optimized quantity is determined
+ by `compute_objective`, which defaults to a function returning the evaluation loss when no metric is provided,
+ the sum of all metrics otherwise.
+
+
+
+ To use this method, you need to have provided a `model_init` when initializing your [`Trainer`]: we need to
+ reinitialize the model at each new run. This is incompatible with the `optimizers` argument, so you need to
+ subclass [`Trainer`] and override the method [`~Trainer.create_optimizer_and_scheduler`] for custom
+ optimizer/scheduler.
+
+
+
+ Args:
+ hp_space (`Callable[["optuna.Trial"], dict[str, float]]`, *optional*):
+ A function that defines the hyperparameter search space. Will default to
+ [`~trainer_utils.default_hp_space_optuna`] or [`~trainer_utils.default_hp_space_ray`] or
+ [`~trainer_utils.default_hp_space_sigopt`] depending on your backend.
+ compute_objective (`Callable[[dict[str, float]], float]`, *optional*):
+ A function computing the objective to minimize or maximize from the metrics returned by the `evaluate`
+ method. Will default to [`~trainer_utils.default_compute_objective`].
+ n_trials (`int`, *optional*, defaults to 100):
+ The number of trial runs to test.
+ direction (`str` or `list[str]`, *optional*, defaults to `"minimize"`):
+ If it's single objective optimization, direction is `str`, can be `"minimize"` or `"maximize"`, you
+ should pick `"minimize"` when optimizing the validation loss, `"maximize"` when optimizing one or
+ several metrics. If it's multi objectives optimization, direction is `list[str]`, can be List of
+ `"minimize"` and `"maximize"`, you should pick `"minimize"` when optimizing the validation loss,
+ `"maximize"` when optimizing one or several metrics.
+ backend (`str` or [`~training_utils.HPSearchBackend`], *optional*):
+ The backend to use for hyperparameter search. Will default to optuna or Ray Tune or SigOpt, depending
+ on which one is installed. If all are installed, will default to optuna.
+ hp_name (`Callable[["optuna.Trial"], str]]`, *optional*):
+ A function that defines the trial/run name. Will default to None.
+ kwargs (`dict[str, Any]`, *optional*):
+ Additional keyword arguments for each backend:
+
+ - `optuna`: parameters from
+ [optuna.study.create_study](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.create_study.html)
+ and also the parameters `timeout`, `n_jobs` and `gc_after_trial` from
+ [optuna.study.Study.optimize](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.Study.html#optuna.study.Study.optimize)
+ - `ray`: parameters from [tune.run](https://docs.ray.io/en/latest/tune/api_docs/execution.html#tune-run).
+ If `resources_per_trial` is not set in the `kwargs`, it defaults to 1 CPU core and 1 GPU (if available).
+ If `progress_reporter` is not set in the `kwargs`,
+ [ray.tune.CLIReporter](https://docs.ray.io/en/latest/tune/api/doc/ray.tune.CLIReporter.html) is used.
+ - `sigopt`: the parameter `proxies` from
+ [sigopt.Connection.set_proxies](https://docs.sigopt.com/support/faq#how-do-i-use-sigopt-with-a-proxy).
+
+ Returns:
+ [`trainer_utils.BestRun` or `list[trainer_utils.BestRun]`]: All the information about the best run or best
+ runs for multi-objective optimization. Experiment summary can be found in `run_summary` attribute for Ray
+ backend.
+ """
+ if backend is None:
+ backend = default_hp_search_backend()
+ backend = HPSearchBackend(backend)
+ backend_obj = ALL_HYPERPARAMETER_SEARCH_BACKENDS[backend]()
+ backend_obj.ensure_available()
+ self.hp_search_backend = backend
+ if self.model_init is None:
+ raise RuntimeError(
+ "To use hyperparameter search, you need to pass your model through a model_init function."
+ )
+
+ self.hp_space = backend_obj.default_hp_space if hp_space is None else hp_space
+ self.hp_name = hp_name
+ self.compute_objective = default_compute_objective if compute_objective is None else compute_objective
+
+ best_run = backend_obj.run(self, n_trials, direction, **kwargs)
+
+ self.hp_search_backend = None
+ return best_run
+
+ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
+ """
+ Log `logs` on the various objects watching training.
+
+ Subclass and override this method to inject custom behavior.
+
+ Args:
+ logs (`dict[str, float]`):
+ The values to log.
+ start_time (`Optional[float]`):
+ The start of training.
+ """
+ if self.state.epoch is not None:
+ logs["epoch"] = self.state.epoch
+ if self.args.include_num_input_tokens_seen != "no":
+ logs["num_input_tokens_seen"] = self.state.num_input_tokens_seen
+ if start_time is not None:
+ logs.update(speed_metrics("train", start_time, num_tokens=self.state.num_input_tokens_seen))
+
+ output = {**logs, **{"step": self.state.global_step}}
+ self.state.log_history.append(output)
+ self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
+
+ def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]:
+ """
+ Prepares one `data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors.
+ """
+ if isinstance(data, Mapping):
+ return type(data)({k: self._prepare_input(v) for k, v in data.items()})
+ elif isinstance(data, (tuple, list)):
+ return type(data)(self._prepare_input(v) for v in data)
+ elif isinstance(data, torch.Tensor):
+ kwargs = {"device": self.args.device}
+ if self.is_deepspeed_enabled and (torch.is_floating_point(data) or torch.is_complex(data)):
+ # NLP models inputs are int/uint and those get adjusted to the right dtype of the
+ # embedding. Other models such as wav2vec2's inputs are already float and thus
+ # may need special handling to match the dtypes of the model
+ kwargs.update({"dtype": self.accelerator.state.deepspeed_plugin.hf_ds_config.dtype()})
+ return data.to(**kwargs)
+ return data
+
+ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
+ """
+ Prepare `inputs` before feeding them to the model, converting them to tensors if they are not already and
+ handling potential state.
+ """
+ inputs = self._prepare_input(inputs)
+ if len(inputs) == 0:
+ raise ValueError(
+ "The batch received was empty, your model won't be able to train on it. Double-check that your "
+ f"training dataset contains keys expected by the model: {','.join(self._signature_columns)}."
+ )
+ if self.args.past_index >= 0 and self._past is not None:
+ inputs["mems"] = self._past
+
+ return inputs
+
+ def _is_attention_mask_causal(self, attention_mask):
+ """
+ Check if an attention mask is causal (compatible with causal attention).
+ Context parallelism only supports causal attention patterns. This function
+ checks if the provided attention mask is compatible.
+
+ Args:
+ attention_mask (torch.Tensor): The attention mask to check
+
+ Returns:
+ bool: True if the mask is causal or compatible with causal attention
+ """
+ if attention_mask is None:
+ return True # No mask is considered causal (model uses default causal masking)
+
+ # Handle different mask dimensions
+ if attention_mask.dim() == 2:
+ # (batch_size, seq_len) - standard padding mask, compatible with causal attention
+ return True
+ elif attention_mask.dim() in [3, 4]:
+ # (batch_size, seq_len, seq_len) or (batch_size, num_heads, seq_len, seq_len)
+ # Check if it's lower triangular (causal)
+ seq_len = attention_mask.shape[-1]
+ if seq_len <= 1:
+ return True # Single token or empty is always causal
+
+ # Take first batch and head (if 4D) for checking pattern
+ if attention_mask.dim() == 4:
+ mask = attention_mask[0, 0] # First batch, first head
+ else:
+ mask = attention_mask[0] # First batch
+
+ # Check if upper triangular part is masked (should be 0 or very negative for causal)
+ upper_triangular = torch.triu(mask, diagonal=1)
+
+ # For causal masks, upper triangular should be 0 or very negative (like -inf)
+ # Use a reasonable threshold to handle float precision issues
+ is_causal = torch.all(upper_triangular <= 1e-6) or torch.all(upper_triangular < -1e4)
+ return is_causal.item() if isinstance(is_causal, torch.Tensor) else is_causal
+
+ # For unknown dimensions, be conservative and reject
+ return False
+
+ def _prepare_context_parallel_inputs(self, model, inputs: dict[str, Union[torch.Tensor, Any]]):
+ """
+ Prepare inputs for context parallelism by setting up buffers and validation.
+
+ Args:
+ model: The model being trained
+ inputs: Input tensors to prepare
+
+ Returns:
+ tuple: (context_manager, prepared_inputs) where context_manager is either
+ the context parallelism wrapper or a no-op context
+ """
+ if (
+ getattr(self.accelerator, "parallelism_config", None) is not None
+ and self.accelerator.parallelism_config.cp_enabled
+ ):
+ if hasattr(model, "config"):
+ if model.config._attn_implementation != "sdpa":
+ raise ValueError(
+ f"Context parallelism is supported only with SDPA attention, you are using {model.config._attn_implementation}."
+ )
+
+ if "position_ids" not in inputs:
+ logger.warning_once("Position IDs not found in the inputs, generating manually")
+ inputs["position_ids"] = torch.arange(
+ inputs["input_ids"].size(1), device=inputs["input_ids"].device
+ ).expand(inputs["input_ids"].size(0), -1)
+ if "shift_labels" not in inputs:
+ logger.warning_once("Shift labels not found in the inputs, shifting manually")
+ if "labels" in inputs:
+ _ignore_index = -100
+ labels = nn.functional.pad(inputs["labels"], (0, 1), value=_ignore_index)
+ inputs["shift_labels"] = labels[:, 1:].contiguous()
+
+ buffers = []
+ buffer_seq_dims = []
+
+ if "input_ids" in inputs:
+ buffers.append(inputs["input_ids"])
+ buffer_seq_dims.append(1) # Sequence dimension
+ if "labels" in inputs:
+ buffers.append(inputs["labels"])
+ buffer_seq_dims.append(1)
+ if "shift_labels" in inputs:
+ buffers.append(inputs["shift_labels"])
+ buffer_seq_dims.append(1)
+ # Add attention_mask to buffers for context parallel splitting (only if causal)
+ if "attention_mask" in inputs:
+ # Only validate causal mask once for performance
+ if not getattr(self, "_attn_mask_causal_checked", False):
+ # Context parallel currently doesn't support other masks than causal
+ # Accelerate applies hooks to replace mask with is_causal arg in SDPA
+ # Check if the mask is really causal and if not throw an error
+ attention_mask = inputs["attention_mask"]
+ if not self._is_attention_mask_causal(attention_mask):
+ raise ValueError(
+ "Context parallelism only supports causal attention masks. "
+ "The provided attention_mask is not causal. "
+ "Please ensure your data uses causal masking (lower triangular) "
+ "or remove the attention_mask to use the model's default causal masking."
+ )
+ self._attn_mask_causal_checked = True
+ if self._attn_mask_causal_checked:
+ # Add to buffers only after validation (or if validation already passed)
+ attention_mask = inputs["attention_mask"]
+ if attention_mask.dim() == 2:
+ buffers.append(attention_mask)
+ buffer_seq_dims.append(1)
+ else:
+ # Other dimensionality; keep as-is without sharding to avoid incorrect splits
+ pass
+ # Include position_ids in context parallelism splitting
+ if "position_ids" in inputs and inputs["position_ids"] is not None:
+ buffers.append(inputs["position_ids"])
+ buffer_seq_dims.append(1)
+
+ return partial(
+ self.accelerator.maybe_context_parallel,
+ buffers=buffers,
+ buffer_seq_dims=buffer_seq_dims,
+ no_restore_buffers=set(buffers),
+ ), inputs
+
+ return contextlib.nullcontext, inputs
+
+ def compute_loss_context_manager(self):
+ """
+ A helper wrapper to group together context managers.
+ """
+ ctx_stack = contextlib.ExitStack()
+
+ autocast_ctx = self.autocast_smart_context_manager()
+ if not isinstance(autocast_ctx, contextlib.nullcontext):
+ ctx_stack.enter_context(autocast_ctx)
+
+ return ctx_stack
+
+ def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True):
+ """
+ A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired
+ arguments, depending on the situation.
+ """
+ if self.use_cpu_amp:
+ ctx_manager = torch.autocast(device_type="cpu", cache_enabled=cache_enabled, dtype=self.amp_dtype)
+ else:
+ ctx_manager = contextlib.nullcontext()
+
+ return ctx_manager
+
+ def training_step(
+ self,
+ model: nn.Module,
+ inputs: dict[str, Union[torch.Tensor, Any]],
+ num_items_in_batch: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """
+ Perform a training step on a batch of inputs.
+
+ Subclass and override to inject custom behavior.
+
+ Args:
+ model (`nn.Module`):
+ The model to train.
+ inputs (`dict[str, Union[torch.Tensor, Any]]`):
+ The inputs and targets of the model.
+
+ The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
+ argument `labels`. Check your model's documentation for all accepted arguments.
+
+ Return:
+ `torch.Tensor`: The tensor with training loss on this batch.
+ """
+ # Prepare buffers for context parallelism
+
+ cp_context, inputs = self._prepare_context_parallel_inputs(model, inputs)
+
+ # Context manager is no-op if CP isn't enabled
+ with cp_context():
+ model.train()
+ if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
+ self.optimizer.train()
+
+ inputs = self._prepare_inputs(inputs)
+ if is_sagemaker_mp_enabled():
+ loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
+ return loss_mb.reduce_mean().detach().to(self.args.device)
+
+ with self.compute_loss_context_manager():
+ loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
+
+ del inputs
+ if (
+ self.args.torch_empty_cache_steps is not None
+ and self.state.global_step % self.args.torch_empty_cache_steps == 0
+ ):
+ if is_torch_xpu_available():
+ torch.xpu.empty_cache()
+ elif is_torch_mlu_available():
+ torch.mlu.empty_cache()
+ elif is_torch_musa_available():
+ torch.musa.empty_cache()
+ elif is_torch_npu_available():
+ torch.npu.empty_cache()
+ elif is_torch_mps_available():
+ torch.mps.empty_cache()
+ elif is_torch_hpu_available():
+ logger.warning(
+ "`torch_empty_cache_steps` is set but HPU device/backend does not support empty_cache()."
+ )
+ else:
+ torch.cuda.empty_cache()
+
+ kwargs = {}
+
+ # For LOMO optimizers you need to explicitly use the learning rate
+ if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
+ kwargs["learning_rate"] = self._get_learning_rate()
+
+ if self.args.n_gpu > 1:
+ loss = loss.mean() # mean() to average on multi-gpu parallel training
+
+ if self.use_apex:
+ from apex import amp
+
+ with amp.scale_loss(loss, self.optimizer) as scaled_loss:
+ scaled_loss.backward()
+ else:
+ # Finally we need to normalize the loss for reporting if GA loss bug is not fixed during compute loss
+ if (
+ not self.model_accepts_loss_kwargs or num_items_in_batch is None
+ ) and self.compute_loss_func is None:
+ # If the model does not accept loss kwargs, we need to normalize the loss by the number of gradient accumulation steps
+ loss = loss / self.current_gradient_accumulation_steps
+
+ # Turning off loss scaling w.r.t. gradient accumulation when DeepSpeed is enabled
+ # https://github.com/huggingface/transformers/pull/35808
+ if self.accelerator.distributed_type == DistributedType.DEEPSPEED:
+ kwargs["scale_wrt_gas"] = False
+
+ self.accelerator.backward(loss, **kwargs)
+
+ return loss.detach()
+
+ def compute_loss(
+ self,
+ model: nn.Module,
+ inputs: dict[str, Union[torch.Tensor, Any]],
+ return_outputs: bool = False,
+ num_items_in_batch: Optional[torch.Tensor] = None,
+ ):
+ """
+ How the loss is computed by Trainer. By default, all models return the loss in the first element.
+
+ Args:
+ model (`nn.Module`):
+ The model to compute the loss for.
+ inputs (`dict[str, Union[torch.Tensor, Any]]`):
+ The input data for the model.
+ return_outputs (`bool`, *optional*, defaults to `False`):
+ Whether to return the model outputs along with the loss.
+ num_items_in_batch (Optional[torch.Tensor], *optional*):
+ The number of items in the batch. If num_items_in_batch is not passed,
+
+ Returns:
+ The loss of the model along with its output if return_outputs was set to True
+
+ Subclass and override for custom behavior. If you are not using `num_items_in_batch` when computing your loss,
+ make sure to overwrite `self.model_accepts_loss_kwargs` to `False`. Otherwise, the loss calculating might be slightly inaccurate when performing gradient accumulation.
+ """
+ if (self.label_smoother is not None or self.compute_loss_func is not None) and "labels" in inputs:
+ labels = inputs.pop("labels")
+ else:
+ labels = None
+ if self.model_accepts_loss_kwargs:
+ kwargs = {}
+ if num_items_in_batch is not None:
+ kwargs["num_items_in_batch"] = num_items_in_batch
+ inputs = {**inputs, **kwargs}
+ outputs = model(**inputs)
+ # Save past state if it exists
+ # TODO: this needs to be fixed and made cleaner later.
+ if self.args.past_index >= 0:
+ self._past = outputs[self.args.past_index]
+
+ # User-defined compute_loss function
+ if self.compute_loss_func is not None:
+ if labels is None:
+ logger.warning(
+ "Trainer: `compute_loss_func` is defined but `labels=None`. "
+ "Your custom loss function will still be called with labels=None. "
+ )
+ loss = self.compute_loss_func(
+ outputs,
+ labels,
+ num_items_in_batch=num_items_in_batch,
+ )
+ # Default HF loss handling (label smoothing) if no custom loss function
+ elif labels is not None:
+ unwrapped_model = self.accelerator.unwrap_model(model)
+ model_name = (
+ unwrapped_model.base_model.model._get_name()
+ if _is_peft_model(unwrapped_model)
+ else unwrapped_model._get_name()
+ )
+ if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
+ loss = self.label_smoother(outputs, labels, shift_labels=True)
+ else:
+ loss = self.label_smoother(outputs, labels)
+ else:
+ if isinstance(outputs, dict) and "loss" not in outputs:
+ raise ValueError(
+ "The model did not return a loss from the inputs, only the following keys: "
+ f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
+ )
+ # We don't use .loss here since the model may return tuples instead of ModelOutput.
+ loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
+
+ if (
+ self.args.average_tokens_across_devices
+ and (self.model_accepts_loss_kwargs or self.compute_loss_func)
+ and num_items_in_batch is not None
+ ):
+ loss *= self.accelerator.num_processes if self.args.n_gpu <= 1 else self.args.n_gpu
+
+ return (loss, outputs) if return_outputs else loss
+
+ def is_local_process_zero(self) -> bool:
+ """
+ Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on several
+ machines) main process.
+ """
+ return self.args.local_process_index == 0
+
+ def is_world_process_zero(self) -> bool:
+ """
+ Whether or not this process is the global main process (when training in a distributed fashion on several
+ machines, this is only going to be `True` for one process).
+ """
+ # Special case for SageMaker ModelParallel since there process_index is dp_process_index, not the global
+ # process index.
+ if is_sagemaker_mp_enabled():
+ return smp.rank() == 0
+ else:
+ return self.args.process_index == 0
+
+ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
+ """
+ Will save the model, so you can reload it using `from_pretrained()`.
+
+ Will only save from the main process.
+ """
+
+ if output_dir is None:
+ output_dir = self.args.output_dir
+
+ if is_torch_xla_available():
+ self._save_tpu(output_dir)
+ elif is_sagemaker_mp_enabled():
+ # Calling the state_dict needs to be done on the wrapped model and on all processes.
+ os.makedirs(output_dir, exist_ok=True)
+ state_dict = self.model_wrapped.state_dict()
+ if self.args.should_save:
+ self._save(output_dir, state_dict=state_dict)
+ if IS_SAGEMAKER_MP_POST_1_10:
+ # 'user_content.pt' indicates model state_dict saved with smp >= 1.10
+ Path(os.path.join(output_dir, "user_content.pt")).touch()
+ # We are in N-D parallelism if we have parallelism_config set, so we check accelerate if we're on a to_save rank
+ elif getattr(self.accelerator, "parallelism_config", None) is not None:
+ if self.accelerator.should_save_model:
+ self._save(output_dir)
+ # If we drop to here, we're in 1D parallelism, so all ranks need to go to `save_pretrained`
+ elif (tp_size := getattr(self.model, "_tp_size", 0)) is not None and tp_size > 1:
+ self._save(output_dir)
+ elif self.is_fsdp_enabled:
+ if "FULL_STATE_DICT" in str(self.accelerator.state.fsdp_plugin.state_dict_type):
+ state_dict = self.accelerator.get_state_dict(self.model)
+ if self.args.should_save:
+ self._save(output_dir, state_dict=state_dict)
+ elif self.is_deepspeed_enabled:
+ try:
+ state_dict = self.accelerator.get_state_dict(self.deepspeed)
+ if self.args.should_save:
+ self._save(output_dir, state_dict=state_dict)
+ except ValueError:
+ logger.warning(
+ " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
+ " zero_to_fp32.py to recover weights"
+ )
+ if self.args.should_save:
+ self._save(output_dir, state_dict={})
+ # remove the dummy state_dict
+ remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
+ self.model_wrapped.save_checkpoint(output_dir)
+
+ elif self.args.should_save:
+ self._save(output_dir)
+
+ # Push to the Hub when `save_model` is called by the user.
+ if self.args.push_to_hub and not _internal_call:
+ self.push_to_hub(commit_message="Model save", revision=self.args.hub_revision)
+
+ def _save_tpu(self, output_dir: Optional[str] = None):
+ output_dir = output_dir if output_dir is not None else self.args.output_dir
+
+ logger.info(f"Saving model checkpoint to {output_dir}")
+ model = self.model
+ xm.mark_step()
+
+ if xm.is_master_ordinal(local=False):
+ os.makedirs(output_dir, exist_ok=True)
+ torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
+
+ # Save a trained model and configuration using `save_pretrained()`.
+ # They can then be reloaded using `from_pretrained()`
+ supported_classes = (PushToHubMixin,)
+ xm.rendezvous("saving_checkpoint")
+ if self.is_fsdp_xla_v1_enabled:
+ ckpt = {
+ "model": model.state_dict(),
+ "shard_metadata": model.get_shard_metadata(),
+ }
+ ckpt_path = os.path.join(
+ output_dir, f"rank{self.args.process_index}-of-{self.args.world_size}-{WEIGHTS_NAME}"
+ )
+ # All ranks save sharded checkpoint
+ xm.save(ckpt, ckpt_path, master_only=False)
+ # Make sure all ranks have saved checkpoints
+ xm.rendezvous("save_full_checkpoints")
+ # Master save full checkpoint
+ if self.args.should_save:
+ from torch_xla.distributed.fsdp import consolidate_sharded_model_checkpoints
+
+ full_state_dict, _ = consolidate_sharded_model_checkpoints(
+ ckpt_prefix=os.path.join(output_dir, ""),
+ ckpt_suffix=f"rank*-of-*-{WEIGHTS_NAME}",
+ save_model=False,
+ )
+ model = model.module.module
+ unwrapped_model = self.accelerator.unwrap_model(model)
+ if isinstance(unwrapped_model, supported_classes):
+ unwrapped_model.save_pretrained(
+ output_dir,
+ state_dict=full_state_dict,
+ save_function=xm.save,
+ safe_serialization=self.args.save_safetensors,
+ )
+ else:
+ logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
+ xm.save(full_state_dict, os.path.join(output_dir, WEIGHTS_NAME))
+ elif not isinstance(model, supported_classes):
+ if isinstance(self.accelerator.unwrap_model(model), supported_classes):
+ self.accelerator.unwrap_model(model).save_pretrained(
+ output_dir,
+ is_main_process=self.args.should_save,
+ state_dict=xm._maybe_convert_to_cpu(model.state_dict()),
+ save_function=xm.save,
+ safe_serialization=self.args.save_safetensors,
+ )
+ else:
+ logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
+ state_dict = xm._maybe_convert_to_cpu(model.state_dict())
+ xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
+ else:
+ model.save_pretrained(
+ output_dir,
+ is_main_process=self.args.should_save,
+ save_function=xm.save,
+ safe_serialization=self.args.save_safetensors,
+ state_dict=xm._maybe_convert_to_cpu(model.state_dict()),
+ )
+ if self.processing_class is not None and self.args.should_save:
+ self.processing_class.save_pretrained(output_dir)
+
+ def _save(self, output_dir: Optional[str] = None, state_dict=None):
+ # If we are executing this function, we are the process zero, so we don't check for that.
+ output_dir = output_dir if output_dir is not None else self.args.output_dir
+ os.makedirs(output_dir, exist_ok=True)
+ logger.info(f"Saving model checkpoint to {output_dir}")
+
+ supported_classes = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)
+ # Save a trained model and configuration using `save_pretrained()`.
+ # They can then be reloaded using `from_pretrained()`
+ if not isinstance(self.model, supported_classes):
+ if state_dict is None:
+ state_dict = self.model.state_dict()
+
+ if isinstance(self.accelerator.unwrap_model(self.model, keep_torch_compile=False), supported_classes):
+ self.accelerator.unwrap_model(self.model, keep_torch_compile=False).save_pretrained(
+ output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
+ )
+ else:
+ logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
+ if self.args.save_safetensors:
+ safetensors.torch.save_file(
+ state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"}
+ )
+ else:
+ torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
+ else:
+ self.model.save_pretrained(
+ output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
+ )
+
+ if self.processing_class is not None:
+ self.processing_class.save_pretrained(output_dir)
+ elif (
+ self.data_collator is not None
+ and hasattr(self.data_collator, "tokenizer")
+ and self.data_collator.tokenizer is not None
+ ):
+ logger.info("Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`")
+ self.data_collator.tokenizer.save_pretrained(output_dir)
+
+ # Good practice: save your training arguments together with the trained model
+ torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
+
+ def store_flos(self):
+ # Storing the number of floating-point operations that went into the model
+ if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+ self.state.total_flos += (
+ distributed_broadcast_scalars([self.current_flos], device=self.args.device).sum().item()
+ )
+ self.current_flos = 0
+ else:
+ self.state.total_flos += self.current_flos
+ self.current_flos = 0
+
+ def _sorted_checkpoints(
+ self, output_dir=None, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False
+ ) -> list[str]:
+ ordering_and_checkpoint_path = []
+
+ glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x)]
+
+ for path in glob_checkpoints:
+ if use_mtime:
+ ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
+ else:
+ regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
+ if regex_match is not None and regex_match.groups() is not None:
+ ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
+
+ checkpoints_sorted = sorted(ordering_and_checkpoint_path)
+ # mtime is not reliable on all filesystems, especially on some fuse fs in cloud environments
+ # so we check if the mtime is fake and fallback to numerical ordering if needed
+ if use_mtime and len(ordering_and_checkpoint_path) > 1:
+ mtime_diff = checkpoints_sorted[-1][0] - checkpoints_sorted[0][0]
+ if mtime_diff < 1.0: # less than 1 second, which is almost impossible when mtime works fine
+ warnings.warn("mtime may not be reliable on this filesystem, falling back to numerical ordering")
+ return self._sorted_checkpoints(
+ use_mtime=False, output_dir=output_dir, checkpoint_prefix=checkpoint_prefix
+ )
+ checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
+
+ # Make sure we don't delete the best model.
+ if (
+ self.state.best_model_checkpoint is not None
+ and str(Path(self.state.best_model_checkpoint)) in checkpoints_sorted
+ ):
+ best_model_index = checkpoints_sorted.index(str(Path(self.state.best_model_checkpoint)))
+ for i in range(best_model_index, len(checkpoints_sorted) - 2):
+ checkpoints_sorted[i], checkpoints_sorted[i + 1] = checkpoints_sorted[i + 1], checkpoints_sorted[i]
+ return checkpoints_sorted
+
+ def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None:
+ if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
+ return
+
+ # Check if we should delete older checkpoint(s)
+ checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime, output_dir=output_dir)
+ if len(checkpoints_sorted) <= self.args.save_total_limit:
+ return
+
+ # If save_total_limit=1 with load_best_model_at_end=True, we could end up deleting the last checkpoint, which
+ # we don't do to allow resuming.
+ save_total_limit = self.args.save_total_limit
+ if (
+ self.state.best_model_checkpoint is not None
+ and self.args.save_total_limit == 1
+ and checkpoints_sorted[-1] != self.state.best_model_checkpoint
+ ):
+ save_total_limit = 2
+
+ number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
+ checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
+ for checkpoint in checkpoints_to_be_deleted:
+ logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
+ shutil.rmtree(checkpoint, ignore_errors=True)
+
+ def evaluate(
+ self,
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
+ ignore_keys: Optional[list[str]] = None,
+ metric_key_prefix: str = "eval",
+ ) -> dict[str, float]:
+ """
+ Run evaluation and returns metrics.
+
+ The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
+ (pass it to the init `compute_metrics` argument).
+
+ You can also subclass and override this method to inject custom behavior.
+
+ Args:
+ eval_dataset (Union[`Dataset`, dict[str, `Dataset`]), *optional*):
+ Pass a dataset if you wish to override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns
+ not accepted by the `model.forward()` method are automatically removed. If it is a dictionary, it will
+ evaluate on each dataset, prepending the dictionary key to the metric name. Datasets must implement the
+ `__len__` method.
+
+
+
+ If you pass a dictionary with names of datasets as keys and datasets as values, evaluate will run
+ separate evaluations on each dataset. This can be useful to monitor how training affects other
+ datasets or simply to get a more fine-grained evaluation.
+ When used with `load_best_model_at_end`, make sure `metric_for_best_model` references exactly one
+ of the datasets. If you, for example, pass in `{"data1": data1, "data2": data2}` for two datasets
+ `data1` and `data2`, you could specify `metric_for_best_model="eval_data1_loss"` for using the
+ loss on `data1` and `metric_for_best_model="eval_data2_loss"` for the loss on `data2`.
+
+
+
+ ignore_keys (`list[str]`, *optional*):
+ A list of keys in the output of your model (if it is a dictionary) that should be ignored when
+ gathering predictions.
+ metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
+ An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
+ "eval_bleu" if the prefix is "eval" (default)
+
+ Returns:
+ A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
+ dictionary also contains the epoch number which comes from the training state.
+ """
+ # handle multiple eval datasets
+ override = eval_dataset is not None
+ eval_dataset = eval_dataset if override else self.eval_dataset
+ if isinstance(eval_dataset, dict):
+ metrics = {}
+ for eval_dataset_name, _eval_dataset in eval_dataset.items():
+ dataset_metrics = self.evaluate(
+ eval_dataset=_eval_dataset if override else eval_dataset_name,
+ ignore_keys=ignore_keys,
+ metric_key_prefix=f"{metric_key_prefix}_{eval_dataset_name}",
+ )
+ metrics.update(dataset_metrics)
+ return metrics
+
+ # memory metrics - must set up as early as possible
+ self._memory_tracker.start()
+
+ eval_dataloader = self.get_eval_dataloader(eval_dataset)
+ if self.is_fsdp_xla_v2_enabled:
+ eval_dataloader = tpu_spmd_dataloader(eval_dataloader)
+
+ start_time = time.time()
+
+ eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
+ output = eval_loop(
+ eval_dataloader,
+ description="Evaluation",
+ # No point gathering the predictions if there are no metrics, otherwise we defer to
+ # self.args.prediction_loss_only
+ prediction_loss_only=True if self.compute_metrics is None else None,
+ ignore_keys=ignore_keys,
+ metric_key_prefix=metric_key_prefix,
+ )
+
+ total_batch_size = self.args.eval_batch_size * self.args.world_size
+ if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
+ start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
+ if f"{metric_key_prefix}_model_preparation_time" in output.metrics:
+ start_time += output.metrics[f"{metric_key_prefix}_model_preparation_time"]
+ output.metrics.update(
+ speed_metrics(
+ metric_key_prefix,
+ start_time,
+ num_samples=output.num_samples,
+ num_steps=math.ceil(output.num_samples / total_batch_size),
+ )
+ )
+
+ self.log(output.metrics)
+
+ if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
+ # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
+ xm.master_print(met.metrics_report())
+
+ self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)
+
+ self._memory_tracker.stop_and_update_metrics(output.metrics)
+
+ return output.metrics
+
+ def predict(
+ self, test_dataset: Dataset, ignore_keys: Optional[list[str]] = None, metric_key_prefix: str = "test"
+ ) -> PredictionOutput:
+ """
+ Run prediction and returns predictions and potential metrics.
+
+ Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method
+ will also return metrics, like in `evaluate()`.
+
+ Args:
+ test_dataset (`Dataset`):
+ Dataset to run the predictions on. If it is an `datasets.Dataset`, columns not accepted by the
+ `model.forward()` method are automatically removed. Has to implement the method `__len__`
+ ignore_keys (`list[str]`, *optional*):
+ A list of keys in the output of your model (if it is a dictionary) that should be ignored when
+ gathering predictions.
+ metric_key_prefix (`str`, *optional*, defaults to `"test"`):
+ An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
+ "test_bleu" if the prefix is "test" (default)
+
+
+
+ If your predictions or labels have different sequence length (for instance because you're doing dynamic padding
+ in a token classification task) the predictions will be padded (on the right) to allow for concatenation into
+ one array. The padding index is -100.
+
+
+
+ Returns: *NamedTuple* A namedtuple with the following keys:
+
+ - predictions (`np.ndarray`): The predictions on `test_dataset`.
+ - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some).
+ - metrics (`dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained
+ labels).
+ """
+ # memory metrics - must set up as early as possible
+ self._memory_tracker.start()
+
+ test_dataloader = self.get_test_dataloader(test_dataset)
+ start_time = time.time()
+
+ eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
+ output = eval_loop(
+ test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
+ )
+ total_batch_size = self.args.eval_batch_size * self.args.world_size
+ if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
+ start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
+ if f"{metric_key_prefix}_model_preparation_time" in output.metrics:
+ start_time += output.metrics[f"{metric_key_prefix}_model_preparation_time"]
+ output.metrics.update(
+ speed_metrics(
+ metric_key_prefix,
+ start_time,
+ num_samples=output.num_samples,
+ num_steps=math.ceil(output.num_samples / total_batch_size),
+ )
+ )
+
+ self.control = self.callback_handler.on_predict(self.args, self.state, self.control, output.metrics)
+ self._memory_tracker.stop_and_update_metrics(output.metrics)
+
+ return PredictionOutput(predictions=output.predictions, label_ids=output.label_ids, metrics=output.metrics)
+
+ def evaluation_loop(
+ self,
+ dataloader: DataLoader,
+ description: str,
+ prediction_loss_only: Optional[bool] = None,
+ ignore_keys: Optional[list[str]] = None,
+ metric_key_prefix: str = "eval",
+ ) -> EvalLoopOutput:
+ """
+ Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
+
+ Works both with or without labels.
+ """
+ args = self.args
+
+ prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only
+
+ # if eval is called w/o train, handle model prep here
+ if self.is_deepspeed_enabled and self.deepspeed is None:
+ _, _ = deepspeed_init(self, num_training_steps=0, inference=True)
+
+ model = self._wrap_model(self.model, training=False, dataloader=dataloader)
+
+ if len(self.accelerator._models) == 0 and model is self.model:
+ start_time = time.time()
+ model = (
+ self.accelerator.prepare(model)
+ if self.is_deepspeed_enabled
+ or (self.is_fsdp_enabled and self.accelerator.mixed_precision != "fp8" and not self.args.torch_compile)
+ else self.accelerator.prepare_model(model, evaluation_mode=True)
+ )
+ self.model_preparation_time = round(time.time() - start_time, 4)
+
+ if self.is_fsdp_enabled:
+ self.model = model
+
+ # for the rest of this function `model` is the outside model, whether it was wrapped or not
+ if model is not self.model:
+ self.model_wrapped = model
+
+ # backward compatibility
+ if self.is_deepspeed_enabled:
+ self.deepspeed = self.model_wrapped
+
+ # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
+ # while ``train`` is running, cast it to the right dtype first and then put on device
+ if not self.is_in_train:
+ if args.fp16_full_eval:
+ model = model.to(dtype=torch.float16, device=args.device)
+ elif args.bf16_full_eval:
+ model = model.to(dtype=torch.bfloat16, device=args.device)
+
+ batch_size = self.args.eval_batch_size
+
+ logger.info(f"\n***** Running {description} *****")
+ if has_length(dataloader):
+ logger.info(f" Num examples = {self.num_examples(dataloader)}")
+ else:
+ logger.info(" Num examples: Unknown")
+ logger.info(f" Batch size = {batch_size}")
+
+ if hasattr(model, "eval") and callable(model.eval):
+ model.eval()
+ if hasattr(self.optimizer, "eval") and callable(self.optimizer.eval):
+ self.optimizer.eval()
+
+ self.callback_handler.eval_dataloader = dataloader
+ # Do this before wrapping.
+ eval_dataset = getattr(dataloader, "dataset", None)
+
+ if args.past_index >= 0:
+ self._past = None
+
+ # Initialize containers
+ all_losses = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
+ all_preds = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
+ all_labels = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
+ all_inputs = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
+
+ metrics = None
+ eval_set_kwargs = {}
+
+ # Will be useful when we have an iterable dataset so don't know its length.
+ observed_num_examples = 0
+
+ # Main evaluation loop
+ for step, inputs in enumerate(dataloader):
+ # Update the observed num examples
+ observed_batch_size = find_batch_size(inputs)
+ if observed_batch_size is not None:
+ observed_num_examples += observed_batch_size
+ # For batch samplers, batch_size is not known by the dataloader in advance.
+ if batch_size is None:
+ batch_size = observed_batch_size
+
+ # Prediction step
+ losses, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
+ main_input_name = getattr(self.model, "main_input_name", "input_ids")
+ inputs_decode = (
+ self._prepare_input(inputs[main_input_name]) if "inputs" in args.include_for_metrics else None
+ )
+
+ if is_torch_xla_available():
+ xm.mark_step()
+
+ # Update containers
+ if losses is not None:
+ losses = self.gather_function(losses.repeat(batch_size))
+ all_losses.add(losses)
+ if inputs_decode is not None:
+ inputs_decode = self.accelerator.pad_across_processes(inputs_decode, dim=1, pad_index=-100)
+ inputs_decode = self.gather_function(inputs_decode)
+ if not self.args.batch_eval_metrics or description == "Prediction":
+ all_inputs.add(inputs_decode)
+ if labels is not None:
+ # Pad labels here, preparing for preprocess_logits_for_metrics in next logits block.
+ labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
+ if logits is not None:
+ logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100)
+ if self.preprocess_logits_for_metrics is not None:
+ logits = self.preprocess_logits_for_metrics(logits, labels)
+ logits = self.gather_function(logits)
+ if not self.args.batch_eval_metrics or description == "Prediction":
+ all_preds.add(logits)
+ if labels is not None:
+ labels = self.gather_function(labels)
+ if not self.args.batch_eval_metrics or description == "Prediction":
+ all_labels.add(labels)
+
+ self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
+
+ if self.args.batch_eval_metrics:
+ if self.compute_metrics is not None and logits is not None and labels is not None:
+ is_last_step = self.accelerator.gradient_state.end_of_dataloader
+ batch_kwargs = {}
+ batch_kwargs["losses"] = losses if "loss" in args.include_for_metrics else None
+ batch_kwargs["inputs"] = inputs if "inputs" in args.include_for_metrics else None
+ metrics = self.compute_metrics(
+ EvalPrediction(predictions=logits, label_ids=labels, **batch_kwargs),
+ compute_result=is_last_step,
+ )
+
+ del losses, logits, labels, inputs
+ torch.cuda.empty_cache()
+
+ # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
+ elif args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
+ all_losses.to_cpu_and_numpy()
+ all_preds.to_cpu_and_numpy()
+ all_labels.to_cpu_and_numpy()
+ all_inputs.to_cpu_and_numpy()
+
+ del losses, logits, labels, inputs
+ torch.cuda.empty_cache()
+
+ # After all calls to `.gather_function`, reset to `gather_for_metrics`:
+ self.gather_function = self.accelerator.gather_for_metrics
+ if args.past_index and hasattr(self, "_past"):
+ # Clean the state at the end of the evaluation loop
+ delattr(self, "_past")
+
+ # Gather all remaining tensors and put them back on the CPU
+ all_losses = all_losses.get_arrays()
+ all_preds = all_preds.get_arrays()
+ all_labels = all_labels.get_arrays()
+ all_inputs = all_inputs.get_arrays()
+
+ # Number of samples
+ if has_length(eval_dataset):
+ num_samples = len(eval_dataset)
+ # The instance check is weird and does not actually check for the type, but whether the dataset has the right
+ # methods. Therefore we need to make sure it also has the attribute.
+ elif isinstance(eval_dataset, IterableDatasetShard) and getattr(eval_dataset, "num_examples", 0) > 0:
+ num_samples = eval_dataset.num_examples
+ else:
+ if has_length(dataloader):
+ num_samples = self.num_examples(dataloader)
+ else: # both len(dataloader.dataset) and len(dataloader) fail
+ num_samples = observed_num_examples
+ if num_samples == 0 and observed_num_examples > 0:
+ num_samples = observed_num_examples
+
+ # Metrics!
+ if (
+ self.compute_metrics is not None
+ and all_preds is not None
+ and all_labels is not None
+ and not self.args.batch_eval_metrics
+ ):
+ eval_set_kwargs["losses"] = all_losses if "loss" in args.include_for_metrics else None
+ eval_set_kwargs["inputs"] = all_inputs if "inputs" in args.include_for_metrics else None
+ metrics = self.compute_metrics(
+ EvalPrediction(predictions=all_preds, label_ids=all_labels, **eval_set_kwargs)
+ )
+ elif metrics is None:
+ metrics = {}
+
+ # To be JSON-serializable, we need to remove numpy types or zero-d tensors
+ metrics = denumpify_detensorize(metrics)
+
+ if isinstance(all_losses, list) and all_losses:
+ metrics[f"{metric_key_prefix}_loss"] = np.concatenate(all_losses).mean().item()
+ elif isinstance(all_losses, np.ndarray):
+ metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()
+ if hasattr(self, "jit_compilation_time"):
+ metrics[f"{metric_key_prefix}_jit_compilation_time"] = self.jit_compilation_time
+ if hasattr(self, "model_preparation_time"):
+ metrics[f"{metric_key_prefix}_model_preparation_time"] = self.model_preparation_time
+
+ # Prefix all keys with metric_key_prefix + '_'
+ for key in list(metrics.keys()):
+ if not key.startswith(f"{metric_key_prefix}_"):
+ metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
+
+ return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples)
+
+ def _nested_gather(self, tensors, name=None):
+ """
+ Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before
+ concatenating them to `gathered`
+ """
+ if tensors is None:
+ return
+ if is_torch_xla_available():
+ if name is None:
+ name = "nested_gather"
+ tensors = nested_xla_mesh_reduce(tensors, name)
+ elif is_sagemaker_mp_enabled():
+ tensors = smp_gather(tensors)
+ elif (self.args.distributed_state is not None and self.args.distributed_state.distributed_type != "NO") or (
+ self.args.distributed_state is None and self.args.local_rank != -1
+ ):
+ tensors = distributed_concat(tensors)
+ return tensors
+
+ def prediction_step(
+ self,
+ model: nn.Module,
+ inputs: dict[str, Union[torch.Tensor, Any]],
+ prediction_loss_only: bool,
+ ignore_keys: Optional[list[str]] = None,
+ ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
+ """
+ Perform an evaluation step on `model` using `inputs`.
+
+ Subclass and override to inject custom behavior.
+
+ Args:
+ model (`nn.Module`):
+ The model to evaluate.
+ inputs (`dict[str, Union[torch.Tensor, Any]]`):
+ The inputs and targets of the model.
+
+ The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
+ argument `labels`. Check your model's documentation for all accepted arguments.
+ prediction_loss_only (`bool`):
+ Whether or not to return the loss only.
+ ignore_keys (`list[str]`, *optional*):
+ A list of keys in the output of your model (if it is a dictionary) that should be ignored when
+ gathering predictions.
+
+ Return:
+ tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
+ logits and labels (each being optional).
+ """
+ has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names)
+ # For CLIP-like models capable of returning loss values.
+ # If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss`
+ # is `True` in `model.forward`.
+ return_loss = inputs.get("return_loss")
+ if return_loss is None:
+ return_loss = self.can_return_loss
+ loss_without_labels = len(self.label_names) == 0 and return_loss
+
+ inputs = self._prepare_inputs(inputs)
+ if ignore_keys is None:
+ if hasattr(self.model, "config"):
+ ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", ["past_key_values"])
+ else:
+ ignore_keys = []
+
+ # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
+ if has_labels or loss_without_labels:
+ labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
+ if len(labels) == 1:
+ labels = labels[0]
+ else:
+ labels = None
+
+ with torch.no_grad():
+ if is_sagemaker_mp_enabled():
+ raw_outputs = smp_forward_only(model, inputs)
+ if has_labels or loss_without_labels:
+ if isinstance(raw_outputs, dict):
+ loss_mb = raw_outputs["loss"]
+ logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"])
+ else:
+ loss_mb = raw_outputs[0]
+ logits_mb = raw_outputs[1:]
+
+ loss = loss_mb.reduce_mean().detach().cpu()
+ logits = smp_nested_concat(logits_mb)
+ else:
+ loss = None
+ if isinstance(raw_outputs, dict):
+ logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys)
+ else:
+ logits_mb = raw_outputs
+ logits = smp_nested_concat(logits_mb)
+ else:
+ if has_labels or loss_without_labels:
+ with self.compute_loss_context_manager():
+ num_items_in_batch = self._get_num_items_in_batch([inputs], self.args.device)
+ loss, outputs = self.compute_loss(
+ model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch
+ )
+ loss = loss.detach().mean()
+
+ if isinstance(outputs, dict):
+ logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
+ else:
+ logits = outputs[1:]
+ else:
+ loss = None
+ with self.compute_loss_context_manager():
+ outputs = model(**inputs)
+ if isinstance(outputs, dict):
+ logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
+ else:
+ logits = outputs
+ # TODO: this needs to be fixed and made cleaner later.
+ if self.args.past_index >= 0:
+ self._past = outputs[self.args.past_index - 1]
+
+ if prediction_loss_only:
+ return (loss, None, None)
+
+ logits = nested_detach(logits)
+ if len(logits) == 1:
+ logits = logits[0]
+
+ return (loss, logits, labels)
+
+ def floating_point_ops(self, inputs: dict[str, Union[torch.Tensor, Any]]):
+ """
+ For models that inherit from [`PreTrainedModel`], uses that method to compute the number of floating point
+ operations for every backward + forward pass. If using another model, either implement such a method in the
+ model or subclass and override this method.
+
+ Args:
+ inputs (`dict[str, Union[torch.Tensor, Any]]`):
+ The inputs and targets of the model.
+
+ Returns:
+ `int`: The number of floating-point operations.
+ """
+ if hasattr(self.model, "floating_point_ops"):
+ return self.model.floating_point_ops(inputs)
+ else:
+ return 0
+
+ def init_hf_repo(self, token: Optional[str] = None):
+ """
+ Initializes a git repo in `self.args.hub_model_id`.
+ """
+ # Only on process zero
+ if not self.is_world_process_zero():
+ return
+
+ if self.args.hub_model_id is None:
+ repo_name = Path(self.args.output_dir).absolute().name
+ else:
+ repo_name = self.args.hub_model_id
+
+ token = token if token is not None else self.args.hub_token
+ repo_url = create_repo(repo_name, token=token, private=self.args.hub_private_repo, exist_ok=True)
+ self.hub_model_id = repo_url.repo_id
+ self.push_in_progress = None
+
+ def create_model_card(
+ self,
+ language: Optional[str] = None,
+ license: Optional[str] = None,
+ tags: Union[str, list[str], None] = None,
+ model_name: Optional[str] = None,
+ finetuned_from: Optional[str] = None,
+ tasks: Union[str, list[str], None] = None,
+ dataset_tags: Union[str, list[str], None] = None,
+ dataset: Union[str, list[str], None] = None,
+ dataset_args: Union[str, list[str], None] = None,
+ ):
+ """
+ Creates a draft of a model card using the information available to the `Trainer`.
+
+ Args:
+ language (`str`, *optional*):
+ The language of the model (if applicable)
+ license (`str`, *optional*):
+ The license of the model. Will default to the license of the pretrained model used, if the original
+ model given to the `Trainer` comes from a repo on the Hub.
+ tags (`str` or `list[str]`, *optional*):
+ Some tags to be included in the metadata of the model card.
+ model_name (`str`, *optional*):
+ The name of the model.
+ finetuned_from (`str`, *optional*):
+ The name of the model used to fine-tune this one (if applicable). Will default to the name of the repo
+ of the original model given to the `Trainer` (if it comes from the Hub).
+ tasks (`str` or `list[str]`, *optional*):
+ One or several task identifiers, to be included in the metadata of the model card.
+ dataset_tags (`str` or `list[str]`, *optional*):
+ One or several dataset tags, to be included in the metadata of the model card.
+ dataset (`str` or `list[str]`, *optional*):
+ One or several dataset identifiers, to be included in the metadata of the model card.
+ dataset_args (`str` or `list[str]`, *optional*):
+ One or several dataset arguments, to be included in the metadata of the model card.
+ """
+ if not self.is_world_process_zero():
+ return
+
+ model_card_filepath = os.path.join(self.args.output_dir, "README.md")
+ is_peft_library = False
+ if os.path.exists(model_card_filepath):
+ library_name = ModelCard.load(model_card_filepath).data.get("library_name")
+ is_peft_library = library_name == "peft"
+
+ # Append existing tags in `tags`
+ existing_tags = ModelCard.load(model_card_filepath).data.tags
+ if tags is not None and existing_tags is not None:
+ if isinstance(tags, str):
+ tags = [tags]
+ for tag in existing_tags:
+ if tag not in tags:
+ tags.append(tag)
+
+ training_summary = TrainingSummary.from_trainer(
+ self,
+ language=language,
+ license=license,
+ tags=tags,
+ model_name=model_name,
+ finetuned_from=finetuned_from,
+ tasks=tasks,
+ dataset_tags=dataset_tags,
+ dataset=dataset,
+ dataset_args=dataset_args,
+ )
+ model_card = training_summary.to_model_card()
+ with open(model_card_filepath, "w") as f:
+ f.write(model_card)
+
+ if is_peft_library:
+ self.accelerator.unwrap_model(self.model).create_or_update_model_card(self.args.output_dir)
+
+ def _push_from_checkpoint(self, checkpoint_folder):
+ # Only push from one node.
+ if not self.is_world_process_zero() or self.args.hub_strategy == HubStrategy.END:
+ return
+ # If we haven't finished the last push, we don't do this one unless args.hub_always_push=True.
+ if not self.args.hub_always_push and self.push_in_progress is not None and not self.push_in_progress.is_done():
+ return
+
+ output_dir = self.args.output_dir
+ # To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder
+ modeling_files = [CONFIG_NAME, GENERATION_CONFIG_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_NAME]
+ # Add sharded checkpoints if we have an index
+ for index_file in [WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME]:
+ index_path = os.path.join(checkpoint_folder, index_file)
+ if os.path.isfile(index_path):
+ modeling_files.append(index_file)
+ with open(index_path) as f:
+ index = json.loads(f.read())
+ shard_files = list(set(index["weight_map"].values()))
+ modeling_files.extend(shard_files)
+ if is_peft_available():
+ modeling_files.extend([ADAPTER_CONFIG_NAME, ADAPTER_WEIGHTS_NAME, ADAPTER_SAFE_WEIGHTS_NAME])
+ for modeling_file in modeling_files:
+ if os.path.isfile(os.path.join(checkpoint_folder, modeling_file)):
+ shutil.copy(os.path.join(checkpoint_folder, modeling_file), os.path.join(output_dir, modeling_file))
+ # Saving the processing class is fast and we don't know how many files it may have spawned, so we resave it to be sure.
+ if self.processing_class is not None:
+ self.processing_class.save_pretrained(output_dir)
+ # Same for the training arguments
+ torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
+
+ if self.args.save_strategy == SaveStrategy.STEPS:
+ commit_message = f"Training in progress, step {self.state.global_step}"
+ else:
+ commit_message = f"Training in progress, epoch {int(self.state.epoch)}"
+
+ model_push_job = upload_folder(
+ repo_id=self.hub_model_id,
+ folder_path=output_dir,
+ commit_message=commit_message,
+ token=self.args.hub_token,
+ run_as_future=True,
+ ignore_patterns=["_*", f"{PREFIX_CHECKPOINT_DIR}-*"],
+ revision=self.args.hub_revision,
+ )
+
+ push_jobs = [model_push_job]
+
+ if self.args.hub_strategy in [HubStrategy.CHECKPOINT, HubStrategy.ALL_CHECKPOINTS]:
+ path_in_repo = (
+ "last-checkpoint" if self.args.hub_strategy == HubStrategy.CHECKPOINT else Path(checkpoint_folder).name
+ )
+ checkpoint_push = upload_folder(
+ repo_id=self.hub_model_id,
+ folder_path=checkpoint_folder,
+ path_in_repo=path_in_repo,
+ commit_message=commit_message + ", checkpoint",
+ token=self.args.hub_token,
+ run_as_future=True,
+ revision=self.args.hub_revision,
+ )
+ push_jobs.append(checkpoint_push)
+
+ if self.push_in_progress is None or self.push_in_progress.is_done():
+ self.push_in_progress = PushInProgress(push_jobs)
+ else:
+ self.push_in_progress.jobs.extend(push_jobs)
+
+ def _finish_current_push(self):
+ if not hasattr(self, "push_in_progress"):
+ return
+ if self.push_in_progress is not None and not self.push_in_progress.is_done():
+ logger.info("Waiting for the current checkpoint push to be finished, this might take a couple of minutes.")
+ self.push_in_progress.wait_until_done()
+
+ def push_to_hub(
+ self,
+ commit_message: Optional[str] = "End of training",
+ blocking: bool = True,
+ token: Optional[str] = None,
+ revision: Optional[str] = None,
+ **kwargs,
+ ) -> str:
+ """
+ Upload `self.model` and `self.processing_class` to the 🤗 model hub on the repo `self.args.hub_model_id`.
+
+ Parameters:
+ commit_message (`str`, *optional*, defaults to `"End of training"`):
+ Message to commit while pushing.
+ blocking (`bool`, *optional*, defaults to `True`):
+ Whether the function should return only when the `git push` has finished.
+ token (`str`, *optional*, defaults to `None`):
+ Token with write permission to overwrite Trainer's original args.
+ revision (`str`, *optional*):
+ The git revision to commit from. Defaults to the head of the "main" branch.
+ kwargs (`dict[str, Any]`, *optional*):
+ Additional keyword arguments passed along to [`~Trainer.create_model_card`].
+
+ Returns:
+ The URL of the repository where the model was pushed if `blocking=False`, or a `Future` object tracking the
+ progress of the commit if `blocking=True`.
+ """
+ model_name = kwargs.pop("model_name", None)
+ if model_name is None and self.args.should_save:
+ if self.args.hub_model_id is None:
+ model_name = Path(self.args.output_dir).name
+ else:
+ model_name = self.args.hub_model_id.split("/")[-1]
+ token = token if token is not None else self.args.hub_token
+
+ # In case the user calls this method with args.push_to_hub = False
+ if self.hub_model_id is None:
+ self.init_hf_repo(token=token)
+
+ # Needs to be executed on all processes for TPU training, but will only save on the processed determined by
+ # self.args.should_save.
+ self.save_model(_internal_call=True)
+
+ # Only push from one node.
+ if not self.is_world_process_zero():
+ return
+
+ # Add additional tags in the case the model has already some tags and users pass
+ # "tags" argument to `push_to_hub` so that trainer automatically handles internal tags
+ # from all models since Trainer does not call `model.push_to_hub`.
+ if getattr(self.model, "model_tags", None) is not None:
+ if "tags" not in kwargs:
+ kwargs["tags"] = []
+
+ # If it is a string, convert it to a list
+ if isinstance(kwargs["tags"], str):
+ kwargs["tags"] = [kwargs["tags"]]
+
+ for model_tag in self.model.model_tags:
+ if model_tag not in kwargs["tags"]:
+ kwargs["tags"].append(model_tag)
+
+ self.create_model_card(model_name=model_name, **kwargs)
+
+ if revision is None:
+ revision = self.args.hub_revision
+
+ # Wait for the current upload to be finished.
+ self._finish_current_push()
+
+ return upload_folder(
+ repo_id=self.hub_model_id,
+ folder_path=self.args.output_dir,
+ commit_message=commit_message,
+ token=token,
+ run_as_future=not blocking,
+ ignore_patterns=["_*", f"{PREFIX_CHECKPOINT_DIR}-*"],
+ revision=revision,
+ )
+
+ #
+ # Deprecated code
+ #
+
+ def prediction_loop(
+ self,
+ dataloader: DataLoader,
+ description: str,
+ prediction_loss_only: Optional[bool] = None,
+ ignore_keys: Optional[list[str]] = None,
+ metric_key_prefix: str = "eval",
+ ) -> EvalLoopOutput:
+ """
+ Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
+
+ Works both with or without labels.
+ """
+ args = self.args
+
+ if not has_length(dataloader):
+ raise ValueError("dataloader must implement a working __len__")
+
+ prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only
+
+ # if eval is called w/o train, handle model prep here
+ if self.is_deepspeed_enabled and self.deepspeed is None:
+ _, _ = deepspeed_init(self, num_training_steps=0, inference=True)
+
+ model = self._wrap_model(self.model, training=False, dataloader=dataloader)
+
+ if len(self.accelerator._models) == 0 and model is self.model:
+ model = (
+ self.accelerator.prepare(model)
+ if self.is_deepspeed_enabled or self.is_fsdp_enabled
+ else self.accelerator.prepare_model(model, evaluation_mode=True)
+ )
+
+ if self.is_fsdp_enabled:
+ self.model = model
+
+ # for the rest of this function `model` is the outside model, whether it was wrapped or not
+ if model is not self.model:
+ self.model_wrapped = model
+
+ # backward compatibility
+ if self.is_deepspeed_enabled:
+ self.deepspeed = self.model_wrapped
+
+ # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
+ # while ``train`` is running, cast it to the right dtype first and then put on device
+ if not self.is_in_train:
+ if args.fp16_full_eval:
+ model = model.to(dtype=torch.float16, device=args.device)
+ elif args.bf16_full_eval:
+ model = model.to(dtype=torch.bfloat16, device=args.device)
+
+ batch_size = (
+ dataloader.total_batch_size
+ if getattr(dataloader, "_is_accelerate_prepared", False)
+ else dataloader.batch_size
+ )
+
+ if batch_size is None:
+ raise ValueError(
+ "Batch size cannot be None. Ensure the dataloader has a valid batch_size or total_batch_size."
+ )
+
+ num_examples = self.num_examples(dataloader)
+ logger.info(f"\n***** Running {description} *****")
+ logger.info(f" Num examples = {num_examples}")
+ logger.info(f" Batch size = {batch_size}")
+
+ losses_host: Optional[torch.Tensor] = None
+ preds_host: Union[torch.Tensor, list[torch.Tensor], None] = None
+ labels_host: Union[torch.Tensor, list[torch.Tensor], None] = None
+ inputs_host: Union[torch.Tensor, list[torch.Tensor], None] = None
+ metrics: Optional[dict] = None
+ eval_set_kwargs: dict = {}
+
+ world_size = max(1, args.world_size)
+
+ eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
+ if not prediction_loss_only:
+ # The actual number of eval_sample can be greater than num_examples in distributed settings (when we pass
+ # a batch size to the sampler)
+ make_multiple_of = None
+ if hasattr(dataloader, "sampler") and isinstance(dataloader.sampler, SequentialDistributedSampler):
+ make_multiple_of = dataloader.sampler.batch_size
+ preds_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
+ labels_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
+ inputs_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
+
+ model.eval()
+ if hasattr(self.optimizer, "eval") and callable(self.optimizer.eval):
+ self.optimizer.eval()
+
+ if args.past_index >= 0:
+ self._past = None
+
+ self.callback_handler.eval_dataloader = dataloader
+
+ for step, inputs in enumerate(dataloader):
+ loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
+ main_input_name = getattr(self.model, "main_input_name", "input_ids")
+ inputs_decode = (
+ self._prepare_input(inputs[main_input_name]) if "inputs" in args.include_for_metrics else None
+ )
+
+ if loss is not None:
+ losses = loss.repeat(batch_size)
+ losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
+ if logits is not None:
+ preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
+ if labels is not None:
+ labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
+ if inputs_decode is not None:
+ inputs_host = (
+ inputs_decode
+ if inputs_host is None
+ else nested_concat(inputs_host, inputs_decode, padding_index=-100)
+ )
+ self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
+
+ if self.args.batch_eval_metrics:
+ if self.compute_metrics is not None and preds_host is not None and labels_host is not None:
+ is_last_step = self.accelerator.gradient_state.end_of_dataloader
+ batch_kwargs = {}
+ batch_kwargs["losses"] = losses_host if "loss" in args.include_for_metrics else None
+ batch_kwargs["inputs"] = inputs_host if "inputs" in args.include_for_metrics else None
+ metrics = self.compute_metrics(
+ EvalPrediction(predictions=preds_host, label_ids=labels_host, **batch_kwargs),
+ compute_result=is_last_step,
+ )
+
+ if self.args.batch_eval_metrics or (
+ args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0
+ ):
+ # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
+ eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
+ if not prediction_loss_only:
+ preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
+ labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
+ inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids"))
+
+ # Set back to None to begin a new accumulation
+ del losses_host, preds_host, labels_host, inputs_host
+ torch.cuda.empty_cache()
+ losses_host, preds_host, labels_host, inputs_host = None, None, None, None
+
+ if args.past_index and hasattr(self, "_past"):
+ # Clean the state at the end of the evaluation loop
+ delattr(self, "_past")
+
+ # Gather all remaining tensors and put them back on the CPU
+ eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
+ if not prediction_loss_only:
+ preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
+ labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
+ inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids"))
+
+ eval_loss = eval_losses_gatherer.finalize()
+ preds = preds_gatherer.finalize() if not prediction_loss_only else None
+ label_ids = labels_gatherer.finalize() if not prediction_loss_only else None
+ inputs_ids = inputs_gatherer.finalize() if not prediction_loss_only else None
+
+ if (
+ self.compute_metrics is not None
+ and preds is not None
+ and label_ids is not None
+ and not self.args.batch_eval_metrics
+ ):
+ eval_set_kwargs["losses"] = eval_loss if "loss" in args.include_for_metrics else None
+ eval_set_kwargs["inputs"] = inputs_ids if "inputs" in args.include_for_metrics else None
+ metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids, **eval_set_kwargs))
+ elif metrics is None:
+ metrics = {}
+
+ # To be JSON-serializable, we need to remove numpy types or zero-d tensors
+ metrics = denumpify_detensorize(metrics)
+
+ if eval_loss is not None:
+ metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item()
+
+ # Prefix all keys with metric_key_prefix + '_'
+ for key in list(metrics.keys()):
+ if not key.startswith(f"{metric_key_prefix}_"):
+ metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
+
+ return EvalLoopOutput(predictions=preds, label_ids=label_ids, metrics=metrics, num_samples=num_examples)
+
+ def _gather_and_numpify(self, tensors, name):
+ """
+ Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before
+ concatenating them to `gathered`
+ """
+ if tensors is None:
+ return
+ if is_torch_xla_available():
+ tensors = nested_xla_mesh_reduce(tensors, name)
+ elif is_sagemaker_mp_enabled():
+ tensors = smp_gather(tensors)
+ elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+ tensors = distributed_concat(tensors)
+
+ return nested_numpify(tensors)
+
+ def _add_sm_patterns_to_gitignore(self) -> None:
+ """Add SageMaker Checkpointing patterns to .gitignore file."""
+ # Make sure we only do this on the main process
+ if not self.is_world_process_zero():
+ return
+
+ patterns = ["*.sagemaker-uploading", "*.sagemaker-uploaded"]
+
+ # Get current .gitignore content
+ if os.path.exists(os.path.join(self.repo.local_dir, ".gitignore")):
+ with open(os.path.join(self.repo.local_dir, ".gitignore")) as f:
+ current_content = f.read()
+ else:
+ current_content = ""
+
+ # Add the patterns to .gitignore
+ content = current_content
+ for pattern in patterns:
+ if pattern not in content:
+ if content.endswith("\n"):
+ content += pattern
+ else:
+ content += f"\n{pattern}"
+
+ # Write the .gitignore file if it has changed
+ if content != current_content:
+ with open(os.path.join(self.repo.local_dir, ".gitignore"), "w") as f:
+ logger.debug(f"Writing .gitignore file. Content: {content}")
+ f.write(content)
+
+ self.repo.git_add(".gitignore")
+
+ # avoid race condition with git status
+ time.sleep(0.5)
+
+ if not self.repo.is_repo_clean():
+ self.repo.git_commit("Add *.sagemaker patterns to .gitignore.")
+ self.repo.git_push()
+
+ def create_accelerator_and_postprocess(self):
+ # We explicitly don't rely on the `Accelerator` to do gradient accumulation
+ grad_acc_kwargs = {}
+ if is_accelerate_available("0.28.0") and self.args.accelerator_config.gradient_accumulation_kwargs is not None:
+ grad_acc_kwargs = self.args.accelerator_config.gradient_accumulation_kwargs
+
+ # check if num_steps is attempted to be passed in gradient_accumulation_kwargs
+ if "num_steps" in grad_acc_kwargs:
+ if self.args.gradient_accumulation_steps > 1:
+ # raise because we do not know which setting is intended.
+ raise ValueError(
+ "The `AcceleratorConfig`'s `num_steps` is set but `gradient_accumulation_steps` is greater than 1 in the passed `TrainingArguments`"
+ "If using the passed `AcceleratorConfig` is desired, do not set the `TrainingArguments` `gradient_accumulation_steps`."
+ )
+ else:
+ self.args.gradient_accumulation_steps = grad_acc_kwargs["num_steps"]
+
+ accelerator_config = self.args.accelerator_config.to_dict()
+
+ if is_accelerate_available("0.28.0"):
+ # Extract dataloader config params from accelerator config
+ dataloader_params = ["split_batches", "dispatch_batches", "even_batches", "use_seedable_sampler"]
+ dataloader_config = DataLoaderConfiguration(
+ **{param: accelerator_config.pop(param) for param in dataloader_params}
+ )
+ if is_accelerate_available("1.1.0"):
+ dataloader_config.data_seed = self.args.data_seed
+
+ non_blocking = accelerator_config.pop("non_blocking")
+ if not is_accelerate_available("0.30.0"):
+ if non_blocking:
+ raise ImportError(
+ "`non_blocking` is only supported in accelerate v0.30.0 and above. Please upgrade accelerate to use this feature."
+ )
+ else:
+ if non_blocking and not self.args.dataloader_pin_memory:
+ logger.warning(
+ "`non_blocking` is enabled but `dataloader_pin_memory` is not. For the best performance, it's recommended to enable both."
+ )
+ dataloader_config.non_blocking = non_blocking
+ # this would have been updated above, no need for it anymore
+ accelerator_config.pop("gradient_accumulation_kwargs")
+
+ args = {
+ "deepspeed_plugin": self.args.deepspeed_plugin,
+ }
+
+ # We defer compatibility checks to accelerator
+ if self.args.parallelism_config is not None:
+ if not is_accelerate_available("1.10.1"):
+ raise ImportError(
+ "ParallelismConfig requires accelerate v1.10.1 and above. Please upgrade accelerate to use this feature."
+ )
+
+ args["parallelism_config"] = self.args.parallelism_config
+
+ if is_accelerate_available("0.28.0"):
+ args["dataloader_config"] = dataloader_config
+ else:
+ args.update(accelerator_config)
+ # tp is initialized at Accelerator init phase so
+ # args should be prepared here
+ if hasattr(self.model, "tp_size") and self.model.tp_size is not None and self.model.tp_size > 1:
+ self.is_tp_enabled = True
+ if version.parse(accelerate_version) > version.parse("1.3.0"):
+ args["torch_tp_plugin"] = TorchTensorParallelPlugin(tp_size=self.model.tp_size)
+ else:
+ raise ValueError("Requires accelerate>1.3.0 to use Tensor Parallelism.")
+
+ # create accelerator object
+ self.accelerator = Accelerator(**args)
+ # some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag
+ self.gather_function = self.accelerator.gather_for_metrics
+
+ if "use_gather_object" in inspect.signature(self.gather_function).parameters:
+ self.gather_function = functools.partial(
+ self.gather_function, use_gather_object=self.args.eval_use_gather_object
+ )
+
+ # deepspeed and accelerate flags covering both trainer args and accelerate launcher
+ self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
+ self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
+ self.is_tp_enabled = getattr(self.accelerator.state, "torch_tp_plugin", None) is not None
+ # post accelerator creation setup
+ if self.is_fsdp_enabled:
+ fsdp_plugin = self.accelerator.state.fsdp_plugin
+ for param in ["limit_all_gathers", "activation_checkpointing"]:
+ setattr(fsdp_plugin, param, self.args.fsdp_config.get(param, getattr(fsdp_plugin, param)))
+ if fsdp_plugin.activation_checkpointing and self.args.gradient_checkpointing:
+ raise ValueError(
+ "The activation_checkpointing in FSDP config and the gradient_checkpointing in training arg "
+ "can't be set to True simultaneously. Please use FSDP's activation_checkpointing logic "
+ "when using FSDP."
+ )
+
+ if self.is_deepspeed_enabled and getattr(self.args, "hf_deepspeed_config", None) is None:
+ self.propagate_args_to_deepspeed()
+
+ # `save_only_model` can't be used with DeepSpeed/FSDP along with `load_best_model_at_end`
+ if (
+ self.args.save_only_model
+ and (self.is_deepspeed_enabled or self.is_fsdp_enabled)
+ and self.args.load_best_model_at_end
+ ):
+ wrapper = "DeepSpeed" if self.is_deepspeed_enabled else "FSDP"
+ raise ValueError(f"{wrapper} can't be used with `save_only_model` along with `load_best_model_at_end`.")
+
+ # `auto_find_batch_size` isn't supported yet with DeepSpeed Zero-3
+ if (
+ self.is_deepspeed_enabled
+ and self.accelerator.state.deepspeed_plugin.zero_stage == 3
+ and self.args.auto_find_batch_size
+ ):
+ raise ValueError(
+ "`auto_find_batch_size` isn't supported yet with DeepSpeed Zero-3. Please consider using Zero-2, Zero-1, or FSDP"
+ )
+ if (
+ self.args.save_only_model
+ and self.is_fsdp_enabled
+ and "SHARDED_STATE_DICT" in str(self.accelerator.state.fsdp_plugin.state_dict_type)
+ ):
+ raise ValueError("save_only_model option is not compatible with FSDP state dict type 'SHARDED_STATE_DICT'")
+
+ def propagate_args_to_deepspeed(self, auto_find_batch_size=False):
+ """
+ Sets values in the deepspeed plugin based on the Trainer args
+ """
+ from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig
+
+ ds_plugin = self.accelerator.state.deepspeed_plugin
+
+ ds_plugin.hf_ds_config = HfTrainerDeepSpeedConfig(ds_plugin.hf_ds_config.config)
+ ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config
+ ds_plugin.hf_ds_config.trainer_config_process(self.args, auto_find_batch_size)
+
+ def _fsdp_qlora_plugin_updates(self):
+ if self.is_fsdp_enabled and _is_peft_model(self.model):
+ from peft import PeftConfig
+ from peft.utils.other import fsdp_auto_wrap_policy
+
+ if isinstance(self.model.active_peft_config, PeftConfig):
+ self.accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(self.model)
+ if (
+ getattr(self.model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
+ and self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage.is_floating_point
+ and version.parse(accelerate_version) > version.parse("0.27.0")
+ ):
+ self.accelerator.state.fsdp_plugin.set_mixed_precision(
+ self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage, override=True
+ )
+
+ def _get_num_items_in_batch(self, batch_samples: list, device: torch.device) -> Optional[Union[torch.Tensor, int]]:
+ """
+ Counts the number of items in the batches to properly scale the loss.
+ Args:
+ batch_samples (`list`): List of batches
+ device (`torch.device`): The device on which the number of items in the batch should be.
+ Returns:
+ None if the number of items in the batch doesn't need to be computed else the number of items in the batch
+ """
+ num_items_in_batch = None
+ count_num_items_in_batch = (
+ len(batch_samples) > 0
+ and "labels" in batch_samples[0]
+ and (
+ # num_items_in_batch is passed to model forward
+ # https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/trainer.py#L3757
+ self.model_accepts_loss_kwargs
+ # num_items_in_batch is passed to compute_loss_func
+ # https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/trainer.py#L3773
+ or self.compute_loss_func is not None
+ # num_items_in_batch is also verified if (self.model_accepts_loss_kwargs or self.compute_loss_func)
+ # https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/trainer.py#L3790
+ )
+ )
+ if count_num_items_in_batch:
+ # For now we don't support object detection
+ try:
+ num_items_in_batch = sum((batch["labels"].ne(-100)).sum() for batch in batch_samples)
+ except (TypeError, AttributeError):
+ pass
+
+ if num_items_in_batch is not None:
+ if self.args.average_tokens_across_devices and self.args.world_size >= 1:
+ num_items_in_batch = self.accelerator.gather(num_items_in_batch.to(device)).sum()
+ elif self.args.n_gpu >= 1:
+ # In DP case, if we don't average, we need to divide by the number of gpu. This is the simplest approximation.
+ # Otherwise, we would have to scatter labels and calculate num_items_in_batch for each gpu.
+ num_items_in_batch = num_items_in_batch // self.args.n_gpu
+
+ if torch.is_tensor(num_items_in_batch):
+ num_items_in_batch = num_items_in_batch.to(device)
+
+ if self.args.n_gpu > 1 and num_items_in_batch.dim() == 0:
+ # In the DataParallel case, convert the scalar tensor into a 2-dim tensor with the same value repeated
+ num_items_in_batch = num_items_in_batch.unsqueeze(0).expand(self.args.n_gpu, -1)
+ # Divide by number of devices with the same batch
+ if pc := getattr(self.accelerator, "parallelism_config", None):
+ num_items_in_batch = num_items_in_batch // pc.non_data_parallel_size
+
+ return num_items_in_batch
+
+ def get_batch_samples(
+ self, epoch_iterator: Iterator, num_batches: int, device: torch.device
+ ) -> tuple[list, Optional[Union[torch.Tensor, int]]]:
+ """
+ Collects a specified number of batches from the epoch iterator and optionally counts the number of items in the batches to properly scale the loss.
+ """
+ batch_samples = []
+
+ for _ in range(num_batches):
+ try:
+ batch_samples.append(next(epoch_iterator))
+ except StopIteration:
+ break
+
+ num_items_in_batch = self._get_num_items_in_batch(batch_samples, device)
+ return batch_samples, num_items_in_batch
+
+ def set_initial_training_values(
+ self, args: TrainingArguments, dataloader: DataLoader, total_train_batch_size: int
+ ):
+ """
+ Calculates and returns the following values:
+ - `num_train_epochs`
+ - `num_update_steps_per_epoch`
+ - `num_examples`
+ - `num_train_samples`
+ - `epoch_based`
+ - `len_dataloader`
+ - `max_steps`
+ """
+ # Case 1: we rely on `args.max_steps` first
+ max_steps = args.max_steps
+ # If max_steps is negative, we use the number of epochs to determine the number of total steps later
+ epoch_based = max_steps < 0
+ len_dataloader = len(dataloader) if has_length(dataloader) else None
+
+ # Case 2: We have a dataloader length and can extrapolate
+ if len_dataloader is not None:
+ num_update_steps_per_epoch = max(
+ len_dataloader // args.gradient_accumulation_steps
+ + int(len_dataloader % args.gradient_accumulation_steps > 0),
+ 1,
+ )
+ # Case 3: We have a length but are using epochs, we can extrapolate the number of steps
+ if epoch_based:
+ max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
+
+ # Now we figure out `num_examples`, `num_train_epochs`, and `train_samples`
+ if len_dataloader:
+ num_examples = self.num_examples(dataloader)
+ if args.max_steps > 0:
+ num_train_epochs = max_steps // num_update_steps_per_epoch + int(
+ max_steps % num_update_steps_per_epoch > 0
+ )
+ # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
+ # the best we can do.
+ num_train_samples = max_steps * total_train_batch_size
+ else:
+ num_train_epochs = math.ceil(args.num_train_epochs)
+ num_train_samples = self.num_examples(dataloader) * args.num_train_epochs
+ elif args.max_steps > 0: # Rely on max_steps when dataloader does not have a working size
+ # Setting a very large number of epochs so we go as many times as necessary over the iterator.
+ num_train_epochs = sys.maxsize
+ num_update_steps_per_epoch = max_steps
+ num_examples = total_train_batch_size * args.max_steps
+ num_train_samples = args.max_steps * total_train_batch_size
+ else:
+ raise ValueError(
+ "args.max_steps must be set to a positive value if dataloader does not have a length, was"
+ f" {args.max_steps}"
+ )
+ return (
+ num_train_epochs,
+ num_update_steps_per_epoch,
+ num_examples,
+ num_train_samples,
+ epoch_based,
+ len_dataloader,
+ max_steps,
+ )
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/trainer_callback.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/trainer_callback.py
new file mode 100644
index 0000000000000000000000000000000000000000..c72bdbb70bcd189582b8b69cd121be6a10c5c4e8
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/trainer_callback.py
@@ -0,0 +1,785 @@
+# Copyright 2020-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Callbacks to use with the Trainer class and customize the training loop.
+"""
+
+import dataclasses
+import json
+import math
+from dataclasses import dataclass
+from typing import Optional, Union
+
+import numpy as np
+from tqdm.auto import tqdm
+
+from .trainer_utils import HPSearchBackend, IntervalStrategy, SaveStrategy, has_length
+from .training_args import TrainingArguments
+from .utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class TrainerState:
+ """
+ A class containing the [`Trainer`] inner state that will be saved along the model and optimizer when checkpointing
+ and passed to the [`TrainerCallback`].
+
+
+
+ In all this class, one step is to be understood as one update step. When using gradient accumulation, one update
+ step may require several forward and backward passes: if you use `gradient_accumulation_steps=n`, then one update
+ step requires going through *n* batches.
+
+
+
+ Args:
+ epoch (`float`, *optional*):
+ Only set during training, will represent the epoch the training is at (the decimal part being the
+ percentage of the current epoch completed).
+ global_step (`int`, *optional*, defaults to 0):
+ During training, represents the number of update steps completed.
+ max_steps (`int`, *optional*, defaults to 0):
+ The number of update steps to do during the current training.
+ logging_steps (`int`, *optional*, defaults to 500):
+ Log every X updates steps
+ eval_steps (`int`, *optional*):
+ Run an evaluation every X steps.
+ save_steps (`int`, *optional*, defaults to 500):
+ Save checkpoint every X updates steps.
+ train_batch_size (`int`, *optional*):
+ The batch size for the training dataloader. Only needed when
+ `auto_find_batch_size` has been used.
+ num_input_tokens_seen (`int`, *optional*, defaults to 0):
+ When tracking the inputs tokens, the number of tokens seen during training (number of input tokens, not the
+ number of prediction tokens).
+ total_flos (`float`, *optional*, defaults to 0):
+ The total number of floating operations done by the model since the beginning of training (stored as floats
+ to avoid overflow).
+ log_history (`list[dict[str, float]]`, *optional*):
+ The list of logs done since the beginning of training.
+ best_metric (`float`, *optional*):
+ When tracking the best model, the value of the best metric encountered so far.
+ best_global_step (`int`, *optional*):
+ When tracking the best model, the step at which the best metric was encountered.
+ Used for setting `best_model_checkpoint`.
+ best_model_checkpoint (`str`, *optional*):
+ When tracking the best model, the value of the name of the checkpoint for the best model encountered so
+ far.
+ is_local_process_zero (`bool`, *optional*, defaults to `True`):
+ Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on
+ several machines) main process.
+ is_world_process_zero (`bool`, *optional*, defaults to `True`):
+ Whether or not this process is the global main process (when training in a distributed fashion on several
+ machines, this is only going to be `True` for one process).
+ is_hyper_param_search (`bool`, *optional*, defaults to `False`):
+ Whether we are in the process of a hyper parameter search using Trainer.hyperparameter_search. This will
+ impact the way data will be logged in TensorBoard.
+ stateful_callbacks (`list[StatefulTrainerCallback]`, *optional*):
+ Callbacks attached to the `Trainer` that should have their states be saved or restored.
+ Relevant callbacks should implement a `state` and `from_state` function.
+ """
+
+ epoch: Optional[float] = None
+ global_step: int = 0
+ max_steps: int = 0
+ logging_steps: int = 500
+ eval_steps: int = 500
+ save_steps: int = 500
+ train_batch_size: Optional[int] = None
+ num_train_epochs: int = 0
+ num_input_tokens_seen: int = 0
+ total_flos: float = 0
+ log_history: list[dict[str, float]] = None
+ best_metric: Optional[float] = None
+ best_global_step: Optional[int] = None
+ best_model_checkpoint: Optional[str] = None
+ is_local_process_zero: bool = True
+ is_world_process_zero: bool = True
+ is_hyper_param_search: bool = False
+ trial_name: Optional[str] = None
+ trial_params: Optional[dict[str, Union[str, float, int, bool]]] = None
+ stateful_callbacks: Optional[list["TrainerCallback"]] = None
+
+ def __post_init__(self):
+ if self.log_history is None:
+ self.log_history = []
+ if self.stateful_callbacks is None:
+ self.stateful_callbacks = {}
+ elif isinstance(self.stateful_callbacks, dict):
+ # We are loading the callbacks in from the state file, no need to process them
+ pass
+ else:
+ # Saveable callbacks get stored as dict of kwargs
+ stateful_callbacks = {}
+ for callback in self.stateful_callbacks:
+ if not isinstance(callback, (ExportableState)):
+ raise TypeError(
+ f"All callbacks passed to be saved must inherit `ExportableState`, but received {type(callback)}"
+ )
+ name = callback.__class__.__name__
+ if name in stateful_callbacks:
+ # We can have multiple versions of the same callback
+ # if so, we store them as a list of states to restore
+ if not isinstance(stateful_callbacks[name], list):
+ stateful_callbacks[name] = [stateful_callbacks[name]]
+ stateful_callbacks[name].append(callback.state())
+ else:
+ stateful_callbacks[name] = callback.state()
+ self.stateful_callbacks = stateful_callbacks
+
+ def save_to_json(self, json_path: str):
+ """Save the content of this instance in JSON format inside `json_path`."""
+ json_string = json.dumps(dataclasses.asdict(self), indent=2, sort_keys=True) + "\n"
+ with open(json_path, "w", encoding="utf-8") as f:
+ f.write(json_string)
+
+ @classmethod
+ def load_from_json(cls, json_path: str):
+ """Create an instance from the content of `json_path`."""
+ with open(json_path, encoding="utf-8") as f:
+ text = f.read()
+ return cls(**json.loads(text))
+
+ def compute_steps(self, args, max_steps):
+ """
+ Calculates and stores the absolute value for logging,
+ eval, and save steps based on if it was a proportion
+ or not.
+ """
+ for step_kind in ("logging", "eval", "save"):
+ num_steps = getattr(args, f"{step_kind}_steps")
+ if num_steps is not None:
+ if num_steps < 1:
+ num_steps = math.ceil(max_steps * num_steps)
+ setattr(self, f"{step_kind}_steps", num_steps)
+
+ def init_training_references(self, trainer, max_steps, num_train_epochs, trial):
+ """
+ Stores the initial training references needed in `self`
+ """
+ if trainer.hp_name is not None and trainer._trial is not None:
+ # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial
+ # parameter to Train when using DDP.
+ self.trial_name = trainer.hp_name(trainer._trial)
+ self.trial_params = None
+ if trial is not None:
+ from transformers.integrations import hp_params
+
+ assignments = trial.assignments if trainer.hp_search_backend == HPSearchBackend.SIGOPT else trial
+ self.trial_params = hp_params(assignments)
+
+ self.max_steps = max_steps
+ self.num_train_epochs = num_train_epochs
+ self.is_local_process_zero = trainer.is_local_process_zero()
+ self.is_world_process_zero = trainer.is_world_process_zero()
+
+
+class ExportableState:
+ """
+ A class for objects that include the ability to have its state
+ be saved during `Trainer._save_checkpoint` and loaded back in during
+ `Trainer._load_from_checkpoint`.
+
+ These must implement a `state` function that gets called during the respective
+ Trainer function call. It should only include parameters and attributes needed to
+ recreate the state at a particular time, to avoid utilizing pickle/maintain standard
+ file IO writing.
+
+ Example:
+
+ ```python
+ class EarlyStoppingCallback(TrainerCallback, ExportableState):
+ def __init__(self, early_stopping_patience: int = 1, early_stopping_threshold: Optional[float] = 0.0):
+ self.early_stopping_patience = early_stopping_patience
+ self.early_stopping_threshold = early_stopping_threshold
+ # early_stopping_patience_counter denotes the number of times validation metrics failed to improve.
+ self.early_stopping_patience_counter = 0
+
+ def state(self) -> dict:
+ return {
+ "args": {
+ "early_stopping_patience": self.early_stopping_patience,
+ "early_stopping_threshold": self.early_stopping_threshold,
+ },
+ "attributes": {
+ "early_stopping_patience_counter": self.early_stopping_patience_counter,
+ }
+ }
+ ```"""
+
+ def state(self) -> dict:
+ raise NotImplementedError("You must implement a `state` function to utilize this class.")
+
+ @classmethod
+ def from_state(cls, state):
+ instance = cls(**state["args"])
+ for k, v in state["attributes"].items():
+ setattr(instance, k, v)
+ return instance
+
+
+@dataclass
+class TrainerControl(ExportableState):
+ """
+ A class that handles the [`Trainer`] control flow. This class is used by the [`TrainerCallback`] to activate some
+ switches in the training loop.
+
+ Args:
+ should_training_stop (`bool`, *optional*, defaults to `False`):
+ Whether or not the training should be interrupted.
+
+ If `True`, this variable will not be set back to `False`. The training will just stop.
+ should_epoch_stop (`bool`, *optional*, defaults to `False`):
+ Whether or not the current epoch should be interrupted.
+
+ If `True`, this variable will be set back to `False` at the beginning of the next epoch.
+ should_save (`bool`, *optional*, defaults to `False`):
+ Whether or not the model should be saved at this step.
+
+ If `True`, this variable will be set back to `False` at the beginning of the next step.
+ should_evaluate (`bool`, *optional*, defaults to `False`):
+ Whether or not the model should be evaluated at this step.
+
+ If `True`, this variable will be set back to `False` at the beginning of the next step.
+ should_log (`bool`, *optional*, defaults to `False`):
+ Whether or not the logs should be reported at this step.
+
+ If `True`, this variable will be set back to `False` at the beginning of the next step.
+ """
+
+ should_training_stop: bool = False
+ should_epoch_stop: bool = False
+ should_save: bool = False
+ should_evaluate: bool = False
+ should_log: bool = False
+
+ def _new_training(self):
+ """Internal method that resets the variable for a new training."""
+ self.should_training_stop = False
+
+ def _new_epoch(self):
+ """Internal method that resets the variable for a new epoch."""
+ self.should_epoch_stop = False
+
+ def _new_step(self):
+ """Internal method that resets the variable for a new step."""
+ self.should_save = False
+ self.should_evaluate = False
+ self.should_log = False
+
+ def state(self) -> dict:
+ return {
+ "args": {
+ "should_training_stop": self.should_training_stop,
+ "should_epoch_stop": self.should_epoch_stop,
+ "should_save": self.should_save,
+ "should_evaluate": self.should_evaluate,
+ "should_log": self.should_log,
+ },
+ "attributes": {},
+ }
+
+
+class TrainerCallback:
+ # no-format
+ """
+ A class for objects that will inspect the state of the training loop at some events and take some decisions. At
+ each of those events the following arguments are available:
+
+ Args:
+ args ([`TrainingArguments`]):
+ The training arguments used to instantiate the [`Trainer`].
+ state ([`TrainerState`]):
+ The current state of the [`Trainer`].
+ control ([`TrainerControl`]):
+ The object that is returned to the [`Trainer`] and can be used to make some decisions.
+ model ([`PreTrainedModel`] or `torch.nn.Module`):
+ The model being trained.
+ tokenizer ([`PreTrainedTokenizer`]):
+ The tokenizer used for encoding the data. This is deprecated in favour of `processing_class`.
+ processing_class ([`PreTrainedTokenizer` or `BaseImageProcessor` or `ProcessorMixin` or `FeatureExtractionMixin`]):
+ The processing class used for encoding the data. Can be a tokenizer, a processor, an image processor or a feature extractor.
+ optimizer (`torch.optim.Optimizer`):
+ The optimizer used for the training steps.
+ lr_scheduler (`torch.optim.lr_scheduler.LambdaLR`):
+ The scheduler used for setting the learning rate.
+ train_dataloader (`torch.utils.data.DataLoader`, *optional*):
+ The current dataloader used for training.
+ eval_dataloader (`torch.utils.data.DataLoader`, *optional*):
+ The current dataloader used for evaluation.
+ metrics (`dict[str, float]`):
+ The metrics computed by the last evaluation phase.
+
+ Those are only accessible in the event `on_evaluate`.
+ logs (`dict[str, float]`):
+ The values to log.
+
+ Those are only accessible in the event `on_log`.
+
+ The `control` object is the only one that can be changed by the callback, in which case the event that changes it
+ should return the modified version.
+
+ The argument `args`, `state` and `control` are positionals for all events, all the others are grouped in `kwargs`.
+ You can unpack the ones you need in the signature of the event using them. As an example, see the code of the
+ simple [`~transformers.PrinterCallback`].
+
+ Example:
+
+ ```python
+ class PrinterCallback(TrainerCallback):
+ def on_log(self, args, state, control, logs=None, **kwargs):
+ _ = logs.pop("total_flos", None)
+ if state.is_local_process_zero:
+ print(logs)
+ ```"""
+
+ def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+ """
+ Event called at the end of the initialization of the [`Trainer`].
+ """
+ pass
+
+ def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+ """
+ Event called at the beginning of training.
+ """
+ pass
+
+ def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+ """
+ Event called at the end of training.
+ """
+ pass
+
+ def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+ """
+ Event called at the beginning of an epoch.
+ """
+ pass
+
+ def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+ """
+ Event called at the end of an epoch.
+ """
+ pass
+
+ def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+ """
+ Event called at the beginning of a training step. If using gradient accumulation, one training step might take
+ several inputs.
+ """
+ pass
+
+ def on_pre_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+ """
+ Event called before the optimizer step but after gradient clipping. Useful for monitoring gradients.
+ """
+ pass
+
+ def on_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+ """
+ Event called after the optimizer step but before gradients are zeroed out. Useful for monitoring gradients.
+ """
+ pass
+
+ def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+ """
+ Event called at the end of an substep during gradient accumulation.
+ """
+ pass
+
+ def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+ """
+ Event called at the end of a training step. If using gradient accumulation, one training step might take
+ several inputs.
+ """
+ pass
+
+ def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+ """
+ Event called after an evaluation phase.
+ """
+ pass
+
+ def on_predict(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics, **kwargs):
+ """
+ Event called after a successful prediction.
+ """
+ pass
+
+ def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+ """
+ Event called after a checkpoint save.
+ """
+ pass
+
+ def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+ """
+ Event called after logging the last logs.
+ """
+ pass
+
+ def on_prediction_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+ """
+ Event called after a prediction step.
+ """
+ pass
+
+
+class CallbackHandler(TrainerCallback):
+ """Internal class that just calls the list of callbacks in order."""
+
+ def __init__(self, callbacks, model, processing_class, optimizer, lr_scheduler):
+ self.callbacks = []
+ for cb in callbacks:
+ self.add_callback(cb)
+ self.model = model
+ self.processing_class = processing_class
+ self.optimizer = optimizer
+ self.lr_scheduler = lr_scheduler
+ self.train_dataloader = None
+ self.eval_dataloader = None
+
+ if not any(isinstance(cb, DefaultFlowCallback) for cb in self.callbacks):
+ logger.warning(
+ "The Trainer will not work properly if you don't have a `DefaultFlowCallback` in its callbacks. You\n"
+ + "should add one before training with `trainer.add_callback(DefaultFlowCallback). The current list of"
+ + "callbacks is\n:"
+ + self.callback_list
+ )
+
+ def add_callback(self, callback):
+ cb = callback() if isinstance(callback, type) else callback
+ cb_class = callback if isinstance(callback, type) else callback.__class__
+ if cb_class in [c.__class__ for c in self.callbacks]:
+ logger.warning(
+ f"You are adding a {cb_class} to the callbacks of this Trainer, but there is already one. The current"
+ + "list of callbacks is\n:"
+ + self.callback_list
+ )
+ self.callbacks.append(cb)
+
+ def pop_callback(self, callback):
+ if isinstance(callback, type):
+ for cb in self.callbacks:
+ if isinstance(cb, callback):
+ self.callbacks.remove(cb)
+ return cb
+ else:
+ for cb in self.callbacks:
+ if cb == callback:
+ self.callbacks.remove(cb)
+ return cb
+
+ def remove_callback(self, callback):
+ if isinstance(callback, type):
+ for cb in self.callbacks:
+ if isinstance(cb, callback):
+ self.callbacks.remove(cb)
+ return
+ else:
+ self.callbacks.remove(callback)
+
+ @property
+ def callback_list(self):
+ return "\n".join(cb.__class__.__name__ for cb in self.callbacks)
+
+ def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
+ return self.call_event("on_init_end", args, state, control)
+
+ def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
+ control.should_training_stop = False
+ return self.call_event("on_train_begin", args, state, control)
+
+ def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
+ return self.call_event("on_train_end", args, state, control)
+
+ def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
+ control.should_epoch_stop = False
+ return self.call_event("on_epoch_begin", args, state, control)
+
+ def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
+ return self.call_event("on_epoch_end", args, state, control)
+
+ def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
+ control.should_log = False
+ control.should_evaluate = False
+ control.should_save = False
+ return self.call_event("on_step_begin", args, state, control)
+
+ def on_pre_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
+ return self.call_event("on_pre_optimizer_step", args, state, control)
+
+ def on_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
+ return self.call_event("on_optimizer_step", args, state, control)
+
+ def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
+ return self.call_event("on_substep_end", args, state, control)
+
+ def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
+ return self.call_event("on_step_end", args, state, control)
+
+ def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics):
+ control.should_evaluate = False
+ return self.call_event("on_evaluate", args, state, control, metrics=metrics)
+
+ def on_predict(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics):
+ return self.call_event("on_predict", args, state, control, metrics=metrics)
+
+ def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
+ control.should_save = False
+ return self.call_event("on_save", args, state, control)
+
+ def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, logs):
+ control.should_log = False
+ return self.call_event("on_log", args, state, control, logs=logs)
+
+ def on_prediction_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
+ return self.call_event("on_prediction_step", args, state, control)
+
+ def call_event(self, event, args, state, control, **kwargs):
+ for callback in self.callbacks:
+ result = getattr(callback, event)(
+ args,
+ state,
+ control,
+ model=self.model,
+ processing_class=self.processing_class,
+ optimizer=self.optimizer,
+ lr_scheduler=self.lr_scheduler,
+ train_dataloader=self.train_dataloader,
+ eval_dataloader=self.eval_dataloader,
+ **kwargs,
+ )
+ # A Callback can skip the return of `control` if it doesn't change it.
+ if result is not None:
+ control = result
+ return control
+
+
+class DefaultFlowCallback(TrainerCallback):
+ """
+ A [`TrainerCallback`] that handles the default flow of the training loop for logs, evaluation and checkpoints.
+ """
+
+ def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+ # Log
+ if state.global_step == 1 and args.logging_first_step:
+ control.should_log = True
+ if args.logging_strategy == IntervalStrategy.STEPS and state.global_step % state.logging_steps == 0:
+ control.should_log = True
+
+ # Evaluate
+ if (
+ args.eval_strategy == IntervalStrategy.STEPS
+ and state.global_step % state.eval_steps == 0
+ and args.eval_delay <= state.global_step
+ ):
+ control.should_evaluate = True
+
+ # Save
+ if (
+ args.save_strategy == SaveStrategy.STEPS
+ and state.save_steps > 0
+ and state.global_step % state.save_steps == 0
+ ):
+ control.should_save = True
+
+ # End training
+ if state.global_step >= state.max_steps:
+ control.should_training_stop = True
+ # Save the model at the end if we have a save strategy
+ if args.save_strategy == SaveStrategy.STEPS:
+ control.should_save = True
+
+ return control
+
+ def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+ # Log
+ if args.logging_strategy == IntervalStrategy.EPOCH:
+ control.should_log = True
+
+ # Evaluate
+ if args.eval_strategy == IntervalStrategy.EPOCH and args.eval_delay <= state.epoch:
+ control.should_evaluate = True
+
+ # Save
+ if args.save_strategy == SaveStrategy.EPOCH:
+ control.should_save = True
+
+ return control
+
+
+class ProgressCallback(TrainerCallback):
+ """
+ A [`TrainerCallback`] that displays the progress of training or evaluation.
+ You can modify `max_str_len` to control how long strings are truncated when logging.
+ """
+
+ def __init__(self, max_str_len: int = 100):
+ """
+ Initialize the callback with optional max_str_len parameter to control string truncation length.
+
+ Args:
+ max_str_len (`int`):
+ Maximum length of strings to display in logs.
+ Longer strings will be truncated with a message.
+ """
+ self.training_bar = None
+ self.prediction_bar = None
+ self.max_str_len = max_str_len
+
+ def on_train_begin(self, args, state, control, **kwargs):
+ if state.is_world_process_zero:
+ self.training_bar = tqdm(total=state.max_steps, dynamic_ncols=True)
+ self.current_step = 0
+
+ def on_step_end(self, args, state, control, **kwargs):
+ if state.is_world_process_zero:
+ self.training_bar.update(state.global_step - self.current_step)
+ self.current_step = state.global_step
+
+ def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
+ if state.is_world_process_zero and has_length(eval_dataloader):
+ if self.prediction_bar is None:
+ self.prediction_bar = tqdm(
+ total=len(eval_dataloader), leave=self.training_bar is None, dynamic_ncols=True
+ )
+ self.prediction_bar.update(1)
+
+ def on_evaluate(self, args, state, control, **kwargs):
+ if state.is_world_process_zero:
+ if self.prediction_bar is not None:
+ self.prediction_bar.close()
+ self.prediction_bar = None
+
+ def on_predict(self, args, state, control, **kwargs):
+ if state.is_world_process_zero:
+ if self.prediction_bar is not None:
+ self.prediction_bar.close()
+ self.prediction_bar = None
+
+ def on_log(self, args, state, control, logs=None, **kwargs):
+ if state.is_world_process_zero and self.training_bar is not None:
+ # make a shallow copy of logs so we can mutate the fields copied
+ # but avoid doing any value pickling.
+ shallow_logs = {}
+ for k, v in logs.items():
+ if isinstance(v, str) and len(v) > self.max_str_len:
+ shallow_logs[k] = (
+ f"[String too long to display, length: {len(v)} > {self.max_str_len}. "
+ "Consider increasing `max_str_len` if needed.]"
+ )
+ else:
+ shallow_logs[k] = v
+ _ = shallow_logs.pop("total_flos", None)
+ # round numbers so that it looks better in console
+ if "epoch" in shallow_logs:
+ shallow_logs["epoch"] = round(shallow_logs["epoch"], 2)
+ self.training_bar.write(str(shallow_logs))
+
+ def on_train_end(self, args, state, control, **kwargs):
+ if state.is_world_process_zero:
+ self.training_bar.close()
+ self.training_bar = None
+
+
+class PrinterCallback(TrainerCallback):
+ """
+ A bare [`TrainerCallback`] that just prints the logs.
+ """
+
+ def on_log(self, args, state, control, logs=None, **kwargs):
+ _ = logs.pop("total_flos", None)
+ if state.is_local_process_zero:
+ print(logs)
+
+
+class EarlyStoppingCallback(TrainerCallback, ExportableState):
+ """
+ A [`TrainerCallback`] that handles early stopping.
+
+ Args:
+ early_stopping_patience (`int`):
+ Use with `metric_for_best_model` to stop training when the specified metric worsens for
+ `early_stopping_patience` evaluation calls.
+ early_stopping_threshold(`float`, *optional*):
+ Use with TrainingArguments `metric_for_best_model` and `early_stopping_patience` to denote how much the
+ specified metric must improve to satisfy early stopping conditions. `
+
+ This callback depends on [`TrainingArguments`] argument *load_best_model_at_end* functionality to set best_metric
+ in [`TrainerState`]. Note that if the [`TrainingArguments`] argument *save_steps* differs from *eval_steps*, the
+ early stopping will not occur until the next save step.
+ """
+
+ def __init__(self, early_stopping_patience: int = 1, early_stopping_threshold: Optional[float] = 0.0):
+ self.early_stopping_patience = early_stopping_patience
+ self.early_stopping_threshold = early_stopping_threshold
+ # early_stopping_patience_counter denotes the number of times validation metrics failed to improve.
+ self.early_stopping_patience_counter = 0
+
+ def check_metric_value(self, args, state, control, metric_value):
+ # best_metric is set by code for load_best_model
+ operator = np.greater if args.greater_is_better else np.less
+ if state.best_metric is None or (
+ operator(metric_value, state.best_metric)
+ and abs(metric_value - state.best_metric) > self.early_stopping_threshold
+ ):
+ self.early_stopping_patience_counter = 0
+ else:
+ self.early_stopping_patience_counter += 1
+
+ def on_train_begin(self, args, state, control, **kwargs):
+ if not args.load_best_model_at_end:
+ logger.warning(
+ "Using EarlyStoppingCallback without load_best_model_at_end=True. "
+ "Once training is finished, the best model will not be loaded automatically."
+ )
+ assert args.metric_for_best_model is not None, (
+ "EarlyStoppingCallback requires metric_for_best_model to be defined"
+ )
+ assert args.eval_strategy != IntervalStrategy.NO, (
+ "EarlyStoppingCallback requires IntervalStrategy of steps or epoch"
+ )
+
+ def on_evaluate(self, args, state, control, metrics, **kwargs):
+ metric_to_check = args.metric_for_best_model
+ if not metric_to_check.startswith("eval_"):
+ metric_to_check = f"eval_{metric_to_check}"
+ metric_value = metrics.get(metric_to_check)
+
+ if metric_value is None:
+ logger.warning(
+ f"early stopping required metric_for_best_model, but did not find {metric_to_check} so early stopping"
+ " is disabled"
+ )
+ return
+
+ self.check_metric_value(args, state, control, metric_value)
+ if self.early_stopping_patience_counter >= self.early_stopping_patience:
+ control.should_training_stop = True
+
+ def state(self) -> dict:
+ return {
+ "args": {
+ "early_stopping_patience": self.early_stopping_patience,
+ "early_stopping_threshold": self.early_stopping_threshold,
+ },
+ "attributes": {
+ "early_stopping_patience_counter": self.early_stopping_patience_counter,
+ },
+ }
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/trainer_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/trainer_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e71367c70c742adc12087a7a82d3ddfe94dca6b
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/trainer_utils.py
@@ -0,0 +1,911 @@
+# Copyright 2020-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+PyTorch-independent utilities for the Trainer class.
+"""
+
+import copy
+import functools
+import gc
+import inspect
+import os
+import random
+import re
+import threading
+import time
+from typing import Any, Callable, NamedTuple, Optional, Union
+
+import numpy as np
+
+from .utils import (
+ ExplicitEnum,
+ is_psutil_available,
+ is_tf_available,
+ is_torch_available,
+ is_torch_cuda_available,
+ is_torch_hpu_available,
+ is_torch_mlu_available,
+ is_torch_mps_available,
+ is_torch_musa_available,
+ is_torch_npu_available,
+ is_torch_xla_available,
+ is_torch_xpu_available,
+ requires_backends,
+)
+
+
+if is_torch_available():
+ import torch
+
+
+def seed_worker(worker_id: int, num_workers: int, rank: int):
+ """
+ Helper function to set worker seed during Dataloader initialization.
+ """
+ init_seed = torch.initial_seed() % 2**32
+ worker_seed = num_workers * rank + init_seed
+ set_seed(worker_seed)
+
+
+def enable_full_determinism(seed: int, warn_only: bool = False):
+ """
+ Helper function for reproducible behavior during distributed training. See
+ - https://pytorch.org/docs/stable/notes/randomness.html for pytorch
+ - https://www.tensorflow.org/api_docs/python/tf/config/experimental/enable_op_determinism for tensorflow
+ """
+ # set seed first
+ set_seed(seed)
+
+ if is_torch_available():
+ # Enable PyTorch deterministic mode. This potentially requires either the environment
+ # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set,
+ # depending on the CUDA version, so we set them both here
+ os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+ os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
+ # The environment variable required to enable deterministic mode on Ascend NPUs.
+ os.environ["ASCEND_LAUNCH_BLOCKING"] = "1"
+ os.environ["HCCL_DETERMINISTIC"] = "1"
+
+ os.environ["FLASH_ATTENTION_DETERMINISTIC"] = "1"
+ torch.use_deterministic_algorithms(True, warn_only=warn_only)
+
+ # Enable CUDNN deterministic mode
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+ if is_tf_available():
+ import tensorflow as tf
+
+ tf.config.experimental.enable_op_determinism()
+
+
+def set_seed(seed: int, deterministic: bool = False):
+ """
+ Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch` and/or `tf` (if installed).
+
+ Args:
+ seed (`int`):
+ The seed to set.
+ deterministic (`bool`, *optional*, defaults to `False`):
+ Whether to use deterministic algorithms where available. Can slow down training.
+ """
+ random.seed(seed)
+ np.random.seed(seed)
+ if is_torch_available():
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ # ^^ safe to call this function even if cuda is not available
+ if deterministic:
+ torch.use_deterministic_algorithms(True)
+ if is_torch_mlu_available():
+ torch.mlu.manual_seed_all(seed)
+ if is_torch_musa_available():
+ torch.musa.manual_seed_all(seed)
+ if is_torch_npu_available():
+ torch.npu.manual_seed_all(seed)
+ if is_torch_hpu_available():
+ torch.hpu.manual_seed_all(seed)
+ if is_torch_xpu_available():
+ torch.xpu.manual_seed_all(seed)
+ if is_tf_available():
+ import tensorflow as tf
+
+ tf.random.set_seed(seed)
+ if deterministic:
+ tf.config.experimental.enable_op_determinism()
+
+
+def neftune_post_forward_hook(module, input, output):
+ """
+ Implements the NEFTune forward pass for the model using forward hooks. Note this works only for torch.nn.Embedding
+ layers. This method is slightly adapted from the original source code that can be found here:
+ https://github.com/neelsjain/NEFTune Simply add it to your model as follows:
+ ```python
+ model = ...
+ model.embed_tokens.neftune_noise_alpha = 0.1
+ model.embed_tokens.register_forward_hook(neftune_post_forward_hook)
+ ```
+ Args:
+ module (`torch.nn.Module`):
+ The embedding module where the hook is attached. Note that you need to set `module.neftune_noise_alpha` to
+ the desired noise alpha value.
+ input (`torch.Tensor`):
+ The input tensor to the model.
+ output (`torch.Tensor`):
+ The output tensor of the model (i.e. the embeddings).
+ """
+ if module.training:
+ dims = torch.tensor(output.size(1) * output.size(2))
+ mag_norm = module.neftune_noise_alpha / torch.sqrt(dims)
+ output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)
+ return output
+
+
+class EvalPrediction:
+ """
+ Evaluation output (always contains labels), to be used to compute metrics.
+
+ Parameters:
+ predictions (`np.ndarray`): Predictions of the model.
+ label_ids (`np.ndarray`): Targets to be matched.
+ inputs (`np.ndarray`, *optional*): Input data passed to the model.
+ losses (`np.ndarray`, *optional*): Loss values computed during evaluation.
+ """
+
+ def __init__(
+ self,
+ predictions: Union[np.ndarray, tuple[np.ndarray]],
+ label_ids: Union[np.ndarray, tuple[np.ndarray]],
+ inputs: Optional[Union[np.ndarray, tuple[np.ndarray]]] = None,
+ losses: Optional[Union[np.ndarray, tuple[np.ndarray]]] = None,
+ ):
+ self.predictions = predictions
+ self.label_ids = label_ids
+ self.inputs = inputs
+ self.losses = losses
+ self.elements = (self.predictions, self.label_ids)
+ if self.inputs is not None:
+ self.elements += (self.inputs,)
+ if self.losses is not None:
+ self.elements += (self.losses,)
+
+ def __iter__(self):
+ return iter(self.elements)
+
+ def __getitem__(self, idx):
+ if idx < 0 or idx >= len(self.elements):
+ raise IndexError("tuple index out of range")
+ return self.elements[idx]
+
+
+class EvalLoopOutput(NamedTuple):
+ predictions: Union[np.ndarray, tuple[np.ndarray]]
+ label_ids: Optional[Union[np.ndarray, tuple[np.ndarray]]]
+ metrics: Optional[dict[str, float]]
+ num_samples: Optional[int]
+
+
+class PredictionOutput(NamedTuple):
+ predictions: Union[np.ndarray, tuple[np.ndarray]]
+ label_ids: Optional[Union[np.ndarray, tuple[np.ndarray]]]
+ metrics: Optional[dict[str, float]]
+
+
+class TrainOutput(NamedTuple):
+ global_step: int
+ training_loss: float
+ metrics: dict[str, float]
+
+
+PREFIX_CHECKPOINT_DIR = "checkpoint"
+_re_checkpoint = re.compile(r"^" + PREFIX_CHECKPOINT_DIR + r"\-(\d+)$")
+
+
+def get_last_checkpoint(folder):
+ content = os.listdir(folder)
+ checkpoints = [
+ path
+ for path in content
+ if _re_checkpoint.search(path) is not None and os.path.isdir(os.path.join(folder, path))
+ ]
+ if len(checkpoints) == 0:
+ return
+ return os.path.join(folder, max(checkpoints, key=lambda x: int(_re_checkpoint.search(x).groups()[0])))
+
+
+class IntervalStrategy(ExplicitEnum):
+ NO = "no"
+ STEPS = "steps"
+ EPOCH = "epoch"
+
+
+class SaveStrategy(ExplicitEnum):
+ NO = "no"
+ STEPS = "steps"
+ EPOCH = "epoch"
+ BEST = "best"
+
+
+class EvaluationStrategy(ExplicitEnum):
+ NO = "no"
+ STEPS = "steps"
+ EPOCH = "epoch"
+
+
+class HubStrategy(ExplicitEnum):
+ END = "end"
+ EVERY_SAVE = "every_save"
+ CHECKPOINT = "checkpoint"
+ ALL_CHECKPOINTS = "all_checkpoints"
+
+
+class BestRun(NamedTuple):
+ """
+ The best run found by a hyperparameter search (see [`~Trainer.hyperparameter_search`]).
+
+ Parameters:
+ run_id (`str`):
+ The id of the best run (if models were saved, the corresponding checkpoint will be in the folder ending
+ with run-{run_id}).
+ objective (`float`):
+ The objective that was obtained for this run.
+ hyperparameters (`dict[str, Any]`):
+ The hyperparameters picked to get this run.
+ run_summary (`Optional[Any]`):
+ A summary of tuning experiments. `ray.tune.ExperimentAnalysis` object for Ray backend.
+ """
+
+ run_id: str
+ objective: Union[float, list[float]]
+ hyperparameters: dict[str, Any]
+ run_summary: Optional[Any] = None
+
+
+def default_compute_objective(metrics: dict[str, float]) -> float:
+ """
+ The default objective to maximize/minimize when doing an hyperparameter search. It is the evaluation loss if no
+ metrics are provided to the [`Trainer`], the sum of all metrics otherwise.
+
+ Args:
+ metrics (`dict[str, float]`): The metrics returned by the evaluate method.
+
+ Return:
+ `float`: The objective to minimize or maximize
+ """
+ metrics = copy.deepcopy(metrics)
+ loss = metrics.pop("eval_loss", None)
+ _ = metrics.pop("epoch", None)
+ # Remove speed metrics
+ speed_metrics = [
+ m for m in metrics if m.endswith("_runtime") or m.endswith("_per_second") or m.endswith("_compilation_time")
+ ]
+ for sm in speed_metrics:
+ _ = metrics.pop(sm, None)
+ return loss if len(metrics) == 0 else sum(metrics.values())
+
+
+def default_hp_space_optuna(trial) -> dict[str, float]:
+ from .integrations import is_optuna_available
+
+ assert is_optuna_available(), "This function needs Optuna installed: `pip install optuna`"
+ return {
+ "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True),
+ "num_train_epochs": trial.suggest_int("num_train_epochs", 1, 5),
+ "seed": trial.suggest_int("seed", 1, 40),
+ "per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [4, 8, 16, 32, 64]),
+ }
+
+
+def default_hp_space_ray(trial) -> dict[str, Any]:
+ from .integrations import is_ray_tune_available
+
+ assert is_ray_tune_available(), "This function needs ray installed: `pip install ray[tune]`"
+ from ray import tune
+
+ return {
+ "learning_rate": tune.loguniform(1e-6, 1e-4),
+ "num_train_epochs": tune.choice(list(range(1, 6))),
+ "seed": tune.uniform(1, 40),
+ "per_device_train_batch_size": tune.choice([4, 8, 16, 32, 64]),
+ }
+
+
+def default_hp_space_sigopt(trial):
+ return [
+ {"bounds": {"min": 1e-6, "max": 1e-4}, "name": "learning_rate", "type": "double", "transformation": "log"},
+ {"bounds": {"min": 1, "max": 6}, "name": "num_train_epochs", "type": "int"},
+ {"bounds": {"min": 1, "max": 40}, "name": "seed", "type": "int"},
+ {
+ "categorical_values": ["4", "8", "16", "32", "64"],
+ "name": "per_device_train_batch_size",
+ "type": "categorical",
+ },
+ ]
+
+
+def default_hp_space_wandb(trial) -> dict[str, Any]:
+ from .integrations import is_wandb_available
+
+ if not is_wandb_available():
+ raise ImportError("This function needs wandb installed: `pip install wandb`")
+
+ return {
+ "method": "random",
+ "metric": {"name": "objective", "goal": "minimize"},
+ "parameters": {
+ "learning_rate": {"distribution": "uniform", "min": 1e-6, "max": 1e-4},
+ "num_train_epochs": {"distribution": "int_uniform", "min": 1, "max": 6},
+ "seed": {"distribution": "int_uniform", "min": 1, "max": 40},
+ "per_device_train_batch_size": {"values": [4, 8, 16, 32, 64]},
+ },
+ }
+
+
+class HPSearchBackend(ExplicitEnum):
+ OPTUNA = "optuna"
+ RAY = "ray"
+ SIGOPT = "sigopt"
+ WANDB = "wandb"
+
+
+def is_main_process(local_rank):
+ """
+ Whether or not the current process is the local process, based on `xr.global_ordinal()` (for TPUs) first, then on
+ `local_rank`.
+ """
+ if is_torch_xla_available():
+ import torch_xla.runtime as xr
+
+ return xr.global_ordinal() == 0
+ return local_rank in [-1, 0]
+
+
+def total_processes_number(local_rank):
+ """
+ Return the number of processes launched in parallel. Works with `torch.distributed` and TPUs.
+ """
+ if is_torch_xla_available():
+ import torch_xla.runtime as xr
+
+ return xr.world_size()
+ elif local_rank != -1 and is_torch_available():
+ import torch
+
+ return torch.distributed.get_world_size()
+ return 1
+
+
+def speed_metrics(split, start_time, num_samples=None, num_steps=None, num_tokens=None):
+ """
+ Measure and return speed performance metrics.
+
+ This function requires a time snapshot `start_time` before the operation to be measured starts and this function
+ should be run immediately after the operation to be measured has completed.
+
+ Args:
+ - split: name to prefix metric (like train, eval, test...)
+ - start_time: operation start time
+ - num_samples: number of samples processed
+ - num_steps: number of steps processed
+ - num_tokens: number of tokens processed
+ """
+ runtime = time.time() - start_time
+ result = {f"{split}_runtime": round(runtime, 4)}
+ if runtime == 0:
+ return result
+ if num_samples is not None:
+ samples_per_second = num_samples / runtime
+ result[f"{split}_samples_per_second"] = round(samples_per_second, 3)
+ if num_steps is not None:
+ steps_per_second = num_steps / runtime
+ result[f"{split}_steps_per_second"] = round(steps_per_second, 3)
+ if num_tokens is not None:
+ tokens_per_second = num_tokens / runtime
+ result[f"{split}_tokens_per_second"] = round(tokens_per_second, 3)
+ return result
+
+
+class SchedulerType(ExplicitEnum):
+ """
+ Scheduler names for the parameter `lr_scheduler_type` in [`TrainingArguments`].
+ By default, it uses "linear". Internally, this retrieves `get_linear_schedule_with_warmup` scheduler from [`Trainer`].
+ Scheduler types:
+ - "linear" = [`get_linear_schedule_with_warmup`]
+ - "cosine" = [`get_cosine_schedule_with_warmup`]
+ - "cosine_with_restarts" = [`get_cosine_with_hard_restarts_schedule_with_warmup`]
+ - "polynomial" = [`get_polynomial_decay_schedule_with_warmup`]
+ - "constant" = [`get_constant_schedule`]
+ - "constant_with_warmup" = [`get_constant_schedule_with_warmup`]
+ - "inverse_sqrt" = [`get_inverse_sqrt_schedule`]
+ - "reduce_lr_on_plateau" = [`get_reduce_on_plateau_schedule`]
+ - "cosine_with_min_lr" = [`get_cosine_with_min_lr_schedule_with_warmup`]
+ - "cosine_warmup_with_min_lr" = [`get_cosine_with_min_lr_schedule_with_warmup_lr_rate`]
+ - "warmup_stable_decay" = [`get_wsd_schedule`]
+ """
+
+ LINEAR = "linear"
+ COSINE = "cosine"
+ COSINE_WITH_RESTARTS = "cosine_with_restarts"
+ POLYNOMIAL = "polynomial"
+ CONSTANT = "constant"
+ CONSTANT_WITH_WARMUP = "constant_with_warmup"
+ INVERSE_SQRT = "inverse_sqrt"
+ REDUCE_ON_PLATEAU = "reduce_lr_on_plateau"
+ COSINE_WITH_MIN_LR = "cosine_with_min_lr"
+ COSINE_WARMUP_WITH_MIN_LR = "cosine_warmup_with_min_lr"
+ WARMUP_STABLE_DECAY = "warmup_stable_decay"
+
+
+class TrainerMemoryTracker:
+ """
+ A helper class that tracks cpu and gpu memory.
+
+ This class will silently skip unless `psutil` is available. Install with `pip install psutil`.
+
+ When a stage completes, it can pass metrics dict to update with the memory metrics gathered during this stage.
+
+ Example :
+
+ ```python
+ self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
+ self._memory_tracker.start()
+ # code ...
+ metrics = {"train_runtime": 10.5}
+ self._memory_tracker.stop_and_update_metrics(metrics)
+ ```
+
+ At the moment GPU tracking is only for `pytorch`, but can be extended to support `tensorflow`.
+
+ To understand this class' intricacies please read the documentation of [`~Trainer.log_metrics`].
+ """
+
+ # map trainer methods to metrics prefix
+ stages = {
+ "__init__": "init",
+ "train": "train",
+ "_inner_training_loop": "train",
+ "evaluate": "eval",
+ "predict": "test",
+ }
+
+ def __init__(self, skip_memory_metrics=False):
+ self.skip_memory_metrics = skip_memory_metrics
+
+ if not is_psutil_available():
+ # soft dependency on psutil
+ self.skip_memory_metrics = True
+
+ if self.skip_memory_metrics:
+ return
+
+ import psutil
+
+ if is_torch_cuda_available() or is_torch_mlu_available() or is_torch_musa_available():
+ import torch
+
+ self.torch = torch
+ self.gpu = {}
+ elif is_torch_mps_available():
+ import torch
+
+ self.torch = torch
+ self.gpu = {}
+ elif is_torch_xpu_available():
+ import torch
+
+ self.torch = torch
+ self.gpu = {}
+ elif is_torch_npu_available():
+ import torch
+
+ self.torch = torch
+ self.gpu = {}
+ elif is_torch_hpu_available():
+ import torch
+
+ self.torch = torch
+ self.gpu = {}
+ else:
+ self.torch = None
+
+ self.process = psutil.Process()
+
+ self.cur_stage = None
+ self.cpu = {}
+ self.init_reported = False
+
+ def derive_stage(self):
+ """derives the stage/caller name automatically"""
+ caller = inspect.currentframe().f_back.f_back.f_code.co_name
+ if caller in self.stages:
+ return self.stages[caller]
+ else:
+ raise ValueError(
+ f"was called from {caller}, but only expect to be called from one of {self.stages.keys()}"
+ )
+
+ def cpu_mem_used(self):
+ """get resident set size memory for the current process"""
+ return self.process.memory_info().rss
+
+ def peak_monitor_func(self):
+ self.cpu_mem_used_peak = -1
+
+ while True:
+ self.cpu_mem_used_peak = max(self.cpu_mem_used(), self.cpu_mem_used_peak)
+
+ # can't sleep or will not catch the peak right (this comment is here on purpose)
+ # time.sleep(0.001) # 1msec
+
+ if not self.peak_monitoring:
+ break
+
+ def start(self):
+ """start tracking for the caller's stage"""
+ if self.skip_memory_metrics:
+ return
+
+ stage = self.derive_stage()
+ # deal with nested calls of eval during train - simply ignore those
+ if self.cur_stage is not None and self.cur_stage != stage:
+ return
+
+ self.cur_stage = stage
+
+ gc.collect()
+
+ if self.torch is not None:
+ if torch.cuda.is_available():
+ self.torch.cuda.reset_peak_memory_stats()
+ self.torch.cuda.empty_cache()
+ elif is_torch_mlu_available():
+ self.torch.mlu.reset_peak_memory_stats()
+ self.torch.mlu.empty_cache()
+ elif is_torch_musa_available():
+ self.torch.musa.reset_peak_memory_stats()
+ self.torch.musa.empty_cache()
+ elif is_torch_xpu_available():
+ self.torch.xpu.reset_peak_memory_stats()
+ self.torch.xpu.empty_cache()
+ elif is_torch_npu_available():
+ self.torch.npu.reset_peak_memory_stats()
+ self.torch.npu.empty_cache()
+ elif is_torch_hpu_available():
+ self.torch.hpu.reset_peak_memory_stats()
+ # not available on hpu as it reserves all device memory for the current process
+ # self.torch.hpu.empty_cache()
+ elif is_torch_mps_available():
+ self.torch.mps.empty_cache()
+
+ # gpu
+ if self.torch is not None:
+ if torch.cuda.is_available():
+ self.gpu_mem_used_at_start = self.torch.cuda.memory_allocated()
+ elif is_torch_mlu_available():
+ self.gpu_mem_used_at_start = self.torch.mlu.memory_allocated()
+ elif is_torch_musa_available():
+ self.gpu_mem_used_at_start = self.torch.musa.memory_allocated()
+ elif is_torch_xpu_available():
+ self.gpu_mem_used_at_start = self.torch.xpu.memory_allocated()
+ elif is_torch_npu_available():
+ self.gpu_mem_used_at_start = self.torch.npu.memory_allocated()
+ elif is_torch_hpu_available():
+ self.gpu_mem_used_at_start = self.torch.hpu.memory_allocated()
+ elif is_torch_mps_available():
+ self.gpu_mem_used_at_start = self.torch.mps.current_allocated_memory()
+
+ # cpu
+ self.cpu_mem_used_at_start = self.cpu_mem_used()
+
+ self.peak_monitoring = True
+ peak_monitor_thread = threading.Thread(target=self.peak_monitor_func)
+ peak_monitor_thread.daemon = True
+ peak_monitor_thread.start()
+
+ def stop(self, stage):
+ """stop tracking for the passed stage"""
+
+ # deal with nested calls of eval during train - simply ignore those
+ if self.cur_stage is not None and self.cur_stage != stage:
+ return
+
+ # this sends a signal to peak_monitor_func to complete its loop
+ self.peak_monitoring = False
+
+ # first ensure all objects get collected and their memory is freed
+ gc.collect()
+
+ if self.torch is not None:
+ if torch.cuda.is_available():
+ self.torch.cuda.empty_cache()
+ elif is_torch_mlu_available():
+ self.torch.mlu.empty_cache()
+ elif is_torch_musa_available():
+ self.torch.musa.empty_cache()
+ elif is_torch_xpu_available():
+ self.torch.xpu.empty_cache()
+ elif is_torch_npu_available():
+ self.torch.npu.empty_cache()
+ elif is_torch_hpu_available():
+ # not available on hpu as it reserves all device memory for the current process
+ # self.torch.npu.empty_cache()
+ pass
+ elif is_torch_mps_available():
+ self.torch.mps.empty_cache()
+
+ # concepts:
+ # - alloc_delta: the difference of allocated memory between the end and the start
+ # - peaked_delta: the difference between the peak memory and the current memory
+ # in order to know how much memory the measured code consumed one needs to sum these two
+
+ # gpu
+ if self.torch is not None:
+ if torch.cuda.is_available():
+ self.gpu_mem_used_now = self.torch.cuda.memory_allocated()
+ self.gpu_mem_used_peak = self.torch.cuda.max_memory_allocated()
+ elif is_torch_mlu_available():
+ self.gpu_mem_used_now = self.torch.mlu.memory_allocated()
+ self.gpu_mem_used_peak = self.torch.mlu.max_memory_allocated()
+ elif is_torch_musa_available():
+ self.gpu_mem_used_now = self.torch.musa.memory_allocated()
+ self.gpu_mem_used_peak = self.torch.musa.max_memory_allocated()
+ elif is_torch_xpu_available():
+ self.gpu_mem_used_now = self.torch.xpu.memory_allocated()
+ self.gpu_mem_used_peak = self.torch.xpu.max_memory_allocated()
+ elif is_torch_npu_available():
+ self.gpu_mem_used_now = self.torch.npu.memory_allocated()
+ self.gpu_mem_used_peak = self.torch.npu.max_memory_allocated()
+ elif is_torch_hpu_available():
+ self.gpu_mem_used_now = self.torch.hpu.memory_allocated()
+ self.gpu_mem_used_peak = self.torch.hpu.max_memory_allocated()
+ elif is_torch_mps_available():
+ self.gpu_mem_used_now = self.torch.mps.current_allocated_memory()
+ # self.torch.mps.max_memory_allocated() does not exist yet
+ self.gpu_mem_used_peak = None
+
+ else:
+ raise ValueError("No available GPU device found!")
+
+ self.gpu[self.cur_stage] = {
+ "begin": self.gpu_mem_used_at_start,
+ "end": self.gpu_mem_used_now,
+ "alloc": (self.gpu_mem_used_now - self.gpu_mem_used_at_start),
+ }
+ if self.gpu_mem_used_peak is not None:
+ self.gpu[self.cur_stage]["peaked"] = max(0, self.gpu_mem_used_peak - self.gpu_mem_used_now)
+ else:
+ self.gpu[self.cur_stage]["peaked"] = "Not available"
+
+ # cpu
+ self.cpu_mem_used_now = self.cpu_mem_used()
+ self.cpu[self.cur_stage] = {
+ "begin": self.cpu_mem_used_at_start,
+ "end": self.cpu_mem_used_now,
+ "alloc": (self.cpu_mem_used_now - self.cpu_mem_used_at_start),
+ "peaked": max(0, self.cpu_mem_used_peak - self.cpu_mem_used_now),
+ }
+
+ # reset - cycle finished
+ self.cur_stage = None
+
+ def update_metrics(self, stage, metrics):
+ """updates the metrics"""
+ if self.skip_memory_metrics:
+ return
+
+ # deal with nested calls of eval during train - simply ignore those
+ if self.cur_stage is not None and self.cur_stage != stage:
+ return
+
+ # since we don't have a way to return init metrics, we push them into the first of train/val/predict
+ stages = [stage]
+ if not self.init_reported:
+ stages.insert(0, "init")
+ self.init_reported = True
+
+ for stage in stages:
+ for t in ["alloc", "peaked"]:
+ if stage in self.cpu and t in self.cpu[stage]:
+ metrics[f"{stage}_mem_cpu_{t}_delta"] = self.cpu[stage][t]
+ if self.torch is not None and stage in self.gpu and t in self.gpu[stage]:
+ metrics[f"{stage}_mem_gpu_{t}_delta"] = self.gpu[stage][t]
+ # if we need additional debug info, enable the following
+ # for t in ["begin", "end"]:
+ # if stage in self.cpu and t in self.cpu[stage]:
+ # metrics[f"{stage}_mem_cpu_{t}"] = self.cpu[stage][t]
+ # if self.torch is not None and stage in self.gpu and t in self.gpu[stage]:
+ # metrics[f"{stage}_mem_gpu_{t}"] = self.gpu[stage][t]
+
+ # since memory can be allocated before init, and it might be difficult to track overall
+ # memory usage, in particular for GPU, let's report memory usage at the point init was called
+ if stages[0] == "init":
+ metrics["before_init_mem_cpu"] = self.cpu["init"]["begin"]
+ if self.torch is not None:
+ metrics["before_init_mem_gpu"] = self.gpu["init"]["begin"]
+ # if we also wanted to report any additional memory allocations in between init and
+ # whatever the next stage was we could also report this:
+ # if self.cpu["init"]["end"] != self.cpu[stage]["begin"]:
+ # metrics[f"after_init_mem_cpu_delta"] = self.cpu[stage]["begin"] - self.cpu["init"]["end"]
+ # if self.torch is not None and self.gpu["init"]["end"] != self.gpu[stage]["begin"]:
+ # metrics[f"after_init_mem_gpu_delta"] = self.gpu[stage]["begin"] - self.gpu["init"]["end"]
+
+ def stop_and_update_metrics(self, metrics=None):
+ """combine stop and metrics update in one call for simpler code"""
+ if self.skip_memory_metrics:
+ return
+
+ stage = self.derive_stage()
+ self.stop(stage)
+
+ # init doesn't have metrics to update so we just save that data for later stages to retrieve
+ if metrics is not None:
+ self.update_metrics(stage, metrics)
+
+
+def has_length(dataset):
+ """
+ Checks if the dataset implements __len__() and it doesn't raise an error
+ """
+ try:
+ return len(dataset) is not None
+ except TypeError:
+ # TypeError: len() of unsized object
+ return False
+ except AttributeError:
+ # Ray DataSets raises an AttributeError: https://github.com/ray-project/ray/blob/master/python/ray/data/dataset.py#L5616
+ return False
+
+
+def denumpify_detensorize(metrics):
+ """
+ Recursively calls `.item()` on the element of the dictionary passed
+ """
+ if isinstance(metrics, (list, tuple)):
+ return type(metrics)(denumpify_detensorize(m) for m in metrics)
+ elif isinstance(metrics, dict):
+ return type(metrics)({k: denumpify_detensorize(v) for k, v in metrics.items()})
+ elif isinstance(metrics, np.generic):
+ return metrics.item()
+ elif is_torch_available() and isinstance(metrics, torch.Tensor) and metrics.numel() == 1:
+ return metrics.item()
+ return metrics
+
+
+def number_of_arguments(func):
+ """
+ Return the number of arguments of the passed function, even if it's a partial function.
+ """
+ if isinstance(func, functools.partial):
+ total_args = len(inspect.signature(func.func).parameters)
+ return total_args - len(func.args) - len(func.keywords)
+ return len(inspect.signature(func).parameters)
+
+
+def find_executable_batch_size(
+ function: Optional[Callable] = None, starting_batch_size: int = 128, auto_find_batch_size: bool = False
+):
+ """
+ Args:
+ A basic decorator that will try to execute `function`. If it fails from exceptions related to out-of-memory or
+ CUDNN, the batch size is multiplied by 0.9 and passed to `function`. `function` must take in a `batch_size` parameter as
+ its first argument.
+ function (`Callable`, *optional*)
+ A function to wrap
+ starting_batch_size (`int`, *optional*)
+ The batch size to try and fit into memory
+ auto_find_batch_size (`bool`, *optional*)
+ If False, will just execute `function`
+ """
+ if function is None:
+ return functools.partial(
+ find_executable_batch_size,
+ starting_batch_size=starting_batch_size,
+ auto_find_batch_size=auto_find_batch_size,
+ )
+
+ if auto_find_batch_size:
+ requires_backends(find_executable_batch_size, "accelerate")
+ from accelerate.utils import find_executable_batch_size as accelerate_find_executable_batch_size
+
+ return accelerate_find_executable_batch_size(function=function, starting_batch_size=starting_batch_size)
+
+ return functools.partial(function, batch_size=starting_batch_size)
+
+
+class FSDPOption(ExplicitEnum):
+ FULL_SHARD = "full_shard"
+ SHARD_GRAD_OP = "shard_grad_op"
+ NO_SHARD = "no_shard"
+ HYBRID_SHARD = "hybrid_shard"
+ HYBRID_SHARD_ZERO2 = "hybrid_shard_zero2"
+ OFFLOAD = "offload"
+ AUTO_WRAP = "auto_wrap"
+
+
+class RemoveColumnsCollator:
+ """Wrap the data collator to remove unused columns before they are passed to the collator."""
+
+ def __init__(
+ self,
+ data_collator,
+ signature_columns,
+ logger=None,
+ model_name: Optional[str] = None,
+ description: Optional[str] = None,
+ ):
+ self.data_collator = data_collator
+ self.signature_columns = signature_columns
+ self.logger = logger
+ self.description = description
+ self.model_name = model_name
+ self.message_logged = False
+
+ def _remove_columns(self, feature: dict) -> dict:
+ if not isinstance(feature, dict):
+ return feature
+ if not self.message_logged and self.logger and self.model_name:
+ ignored_columns = list(set(feature.keys()) - set(self.signature_columns))
+ if len(ignored_columns) > 0:
+ dset_description = "" if self.description is None else f"in the {self.description} set"
+ self.logger.info(
+ f"The following columns {dset_description} don't have a corresponding argument in "
+ f"`{self.model_name}.forward` and have been ignored: {', '.join(ignored_columns)}."
+ f" If {', '.join(ignored_columns)} are not expected by `{self.model_name}.forward`, "
+ " you can safely ignore this message."
+ )
+ self.message_logged = True
+ return {k: v for k, v in feature.items() if k in self.signature_columns}
+
+ def __call__(self, features: list[dict]):
+ features = [self._remove_columns(feature) for feature in features]
+ return self.data_collator(features)
+
+
+def check_target_module_exists(optim_target_modules, key: str, return_is_regex: bool = False):
+ """A helper method to check if the passed module's key name matches any of the target modules in the optim_target_modules.
+
+ Args:
+ optim_target_modules (`Union[str, list[str]]`):
+ A list of strings to try to match. Can be also a full string.
+ key (`str`):
+ A key to search any matches in optim_target_modules
+ return_is_regex (`bool`):
+ If set to `True`, the method will return whether the passed `optim_target_modules`
+ is a regex or not.
+
+ Returns:
+ `bool` : True of match object if key matches any target modules from config, False or
+ None if no match found
+ `bool` : If the matched target module is a regex to silence out the warnings in Trainer
+ for extra modules being found (only if `target_module_found=True` for an array of regex).
+ """
+ target_module_found = False
+ is_regex = False
+
+ if isinstance(optim_target_modules, str):
+ target_module_found = bool(re.fullmatch(optim_target_modules, key))
+ is_regex = optim_target_modules != key
+ elif key in optim_target_modules: # from here, target_module_found must be a list of str
+ # this module is specified directly in target_modules
+ target_module_found = True
+ elif any(target_key in key for target_key in optim_target_modules):
+ target_module_found = True
+ elif any(bool(re.fullmatch(optim_target_module, key)) for optim_target_module in optim_target_modules):
+ target_module_found = True
+ is_regex = True
+
+ if return_is_regex:
+ return target_module_found, is_regex
+
+ return target_module_found
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/training_args.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/training_args.py
new file mode 100644
index 0000000000000000000000000000000000000000..79fdfe6c0c2e824eae03d09b5156312e0fe587c3
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/training_args.py
@@ -0,0 +1,3167 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import contextlib
+import json
+import math
+import os
+import warnings
+from dataclasses import asdict, dataclass, field, fields
+from datetime import timedelta
+from enum import Enum
+from functools import cached_property
+from pathlib import Path
+from typing import Any, Optional, Union
+
+from huggingface_hub import get_full_repo_name
+
+from .debug_utils import DebugOption
+from .trainer_utils import (
+ EvaluationStrategy,
+ FSDPOption,
+ HubStrategy,
+ IntervalStrategy,
+ SaveStrategy,
+ SchedulerType,
+)
+from .utils import (
+ ACCELERATE_MIN_VERSION,
+ ExplicitEnum,
+ is_accelerate_available,
+ is_apex_available,
+ is_ipex_available,
+ is_sagemaker_dp_enabled,
+ is_sagemaker_mp_enabled,
+ is_torch_available,
+ is_torch_bf16_gpu_available,
+ is_torch_cuda_available,
+ is_torch_hpu_available,
+ is_torch_mlu_available,
+ is_torch_mps_available,
+ is_torch_musa_available,
+ is_torch_neuroncore_available,
+ is_torch_npu_available,
+ is_torch_tf32_available,
+ is_torch_xla_available,
+ is_torch_xpu_available,
+ logging,
+ requires_backends,
+)
+from .utils.generic import strtobool
+from .utils.import_utils import is_optimum_neuron_available
+
+
+logger = logging.get_logger(__name__)
+log_levels = logging.get_log_levels_dict().copy()
+trainer_log_levels = dict(**log_levels, passive=-1)
+
+if is_torch_available():
+ import torch
+ import torch.distributed as dist
+
+if is_accelerate_available():
+ from accelerate.state import AcceleratorState, PartialState
+ from accelerate.utils import DistributedType
+
+ from .trainer_pt_utils import AcceleratorConfig
+
+if is_accelerate_available("1.10.1"):
+ from accelerate.parallelism_config import ParallelismConfig
+else:
+ ParallelismConfig = Any
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+if is_torch_neuroncore_available(check_device=False):
+ # torchrun support
+ # https://github.com/pytorch/xla/pull/3609
+ if os.environ.get("TORCHELASTIC_RUN_ID"):
+ if is_optimum_neuron_available():
+ logger.info(
+ "Make sure that you are performing the training with the NeuronTrainer from optimum[neuron], this "
+ "will fail otherwise."
+ )
+ else:
+ logger.warning(
+ "Please use the NeuronTrainer from optimum[neuron] instead of the Transformers library to perform "
+ "training on AWS Trainium instances. More information here: "
+ "https://github.com/huggingface/optimum-neuron"
+ )
+ import torch_xla.distributed.xla_backend as xbn
+
+ if not isinstance(dist.group.WORLD, xbn.ProcessGroupXla):
+ dist.init_process_group(backend="xla")
+ if not isinstance(dist.group.WORLD, xbn.ProcessGroupXla):
+ raise AssertionError("Failed to initialize torch.distributed process group using XLA backend.")
+
+
+if is_sagemaker_mp_enabled():
+ import smdistributed.modelparallel.torch as smp
+
+ smp.init()
+
+
+def default_logdir() -> str:
+ """
+ Same default as PyTorch
+ """
+ import socket
+ from datetime import datetime
+
+ current_time = datetime.now().strftime("%b%d_%H-%M-%S")
+ return os.path.join("runs", current_time + "_" + socket.gethostname())
+
+
+def get_int_from_env(env_keys, default):
+ """Returns the first positive env value found in the `env_keys` list or the default."""
+ for e in env_keys:
+ val = int(os.environ.get(e, "-1"))
+ if val >= 0:
+ return val
+ return default
+
+
+def get_xla_device_type(device: "torch.device") -> Optional[str]:
+ """
+ Returns the xla device type (CPU|GPU|TPU) or None if the device is a non-xla device.
+ """
+ if is_torch_xla_available():
+ if device.type == "cpu":
+ return "CPU"
+ return xm.xla_real_devices([device])[0].split(":")[0]
+ return None
+
+
+class OptimizerNames(ExplicitEnum):
+ """
+ Stores the acceptable string identifiers for optimizers.
+ """
+
+ ADAMW_TORCH = "adamw_torch"
+ ADAMW_TORCH_FUSED = "adamw_torch_fused"
+ ADAMW_TORCH_XLA = "adamw_torch_xla"
+ ADAMW_TORCH_NPU_FUSED = "adamw_torch_npu_fused"
+ ADAMW_APEX_FUSED = "adamw_apex_fused"
+ ADAFACTOR = "adafactor"
+ ADAMW_ANYPRECISION = "adamw_anyprecision"
+ ADAMW_TORCH_4BIT = "adamw_torch_4bit"
+ ADAMW_TORCH_8BIT = "adamw_torch_8bit"
+ ADEMAMIX = "ademamix"
+ SGD = "sgd"
+ ADAGRAD = "adagrad"
+ ADAMW_BNB = "adamw_bnb_8bit"
+ ADAMW_8BIT = "adamw_8bit" # just an alias for adamw_bnb_8bit
+ ADEMAMIX_8BIT = "ademamix_8bit"
+ LION_8BIT = "lion_8bit"
+ LION = "lion_32bit"
+ PAGED_ADAMW = "paged_adamw_32bit"
+ PAGED_ADAMW_8BIT = "paged_adamw_8bit"
+ PAGED_ADEMAMIX = "paged_ademamix_32bit"
+ PAGED_ADEMAMIX_8BIT = "paged_ademamix_8bit"
+ PAGED_LION = "paged_lion_32bit"
+ PAGED_LION_8BIT = "paged_lion_8bit"
+ RMSPROP = "rmsprop"
+ RMSPROP_BNB = "rmsprop_bnb"
+ RMSPROP_8BIT = "rmsprop_bnb_8bit"
+ RMSPROP_32BIT = "rmsprop_bnb_32bit"
+ GALORE_ADAMW = "galore_adamw"
+ GALORE_ADAMW_8BIT = "galore_adamw_8bit"
+ GALORE_ADAFACTOR = "galore_adafactor"
+ GALORE_ADAMW_LAYERWISE = "galore_adamw_layerwise"
+ GALORE_ADAMW_8BIT_LAYERWISE = "galore_adamw_8bit_layerwise"
+ GALORE_ADAFACTOR_LAYERWISE = "galore_adafactor_layerwise"
+ LOMO = "lomo"
+ ADALOMO = "adalomo"
+ GROKADAMW = "grokadamw"
+ SCHEDULE_FREE_RADAM = "schedule_free_radam"
+ SCHEDULE_FREE_ADAMW = "schedule_free_adamw"
+ SCHEDULE_FREE_SGD = "schedule_free_sgd"
+ APOLLO_ADAMW = "apollo_adamw"
+ APOLLO_ADAMW_LAYERWISE = "apollo_adamw_layerwise"
+ STABLE_ADAMW = "stable_adamw"
+
+
+def _convert_str_dict(passed_value: dict):
+ "Safely checks that a passed value is a dictionary and converts any string values to their appropriate types."
+ for key, value in passed_value.items():
+ if isinstance(value, dict):
+ passed_value[key] = _convert_str_dict(value)
+ elif isinstance(value, str):
+ # First check for bool and convert
+ if value.lower() in ("true", "false"):
+ passed_value[key] = value.lower() == "true"
+ # Check for digit
+ elif value.isdigit():
+ passed_value[key] = int(value)
+ elif value.replace(".", "", 1).isdigit():
+ passed_value[key] = float(value)
+
+ return passed_value
+
+
+# TODO: `TrainingArguments` users rely on it being fully mutable. In the future see if we can narrow this to a few keys: https://github.com/huggingface/transformers/pull/25903
+@dataclass
+class TrainingArguments:
+ """
+ TrainingArguments is the subset of the arguments we use in our example scripts **which relate to the training loop
+ itself**.
+
+ Using [`HfArgumentParser`] we can turn this class into
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
+ command line.
+
+ Parameters:
+ output_dir (`str`, *optional*, defaults to `"trainer_output"`):
+ The output directory where the model predictions and checkpoints will be written.
+ overwrite_output_dir (`bool`, *optional*, defaults to `False`):
+ If `True`, overwrite the content of the output directory. Use this to continue training if `output_dir`
+ points to a checkpoint directory.
+ do_train (`bool`, *optional*, defaults to `False`):
+ Whether to run training or not. This argument is not directly used by [`Trainer`], it's intended to be used
+ by your training/evaluation scripts instead. See the [example
+ scripts](https://github.com/huggingface/transformers/tree/main/examples) for more details.
+ do_eval (`bool`, *optional*):
+ Whether to run evaluation on the validation set or not. Will be set to `True` if `eval_strategy` is
+ different from `"no"`. This argument is not directly used by [`Trainer`], it's intended to be used by your
+ training/evaluation scripts instead. See the [example
+ scripts](https://github.com/huggingface/transformers/tree/main/examples) for more details.
+ do_predict (`bool`, *optional*, defaults to `False`):
+ Whether to run predictions on the test set or not. This argument is not directly used by [`Trainer`], it's
+ intended to be used by your training/evaluation scripts instead. See the [example
+ scripts](https://github.com/huggingface/transformers/tree/main/examples) for more details.
+ eval_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"no"`):
+ The evaluation strategy to adopt during training. Possible values are:
+
+ - `"no"`: No evaluation is done during training.
+ - `"steps"`: Evaluation is done (and logged) every `eval_steps`.
+ - `"epoch"`: Evaluation is done at the end of each epoch.
+
+ prediction_loss_only (`bool`, *optional*, defaults to `False`):
+ When performing evaluation and generating predictions, only returns the loss.
+ per_device_train_batch_size (`int`, *optional*, defaults to 8):
+ The batch size *per device*. The **global batch size** is computed as:
+ `per_device_train_batch_size * number_of_devices` in multi-GPU or distributed setups.
+ per_device_eval_batch_size (`int`, *optional*, defaults to 8):
+ The batch size per device accelerator core/CPU for evaluation.
+ gradient_accumulation_steps (`int`, *optional*, defaults to 1):
+ Number of updates steps to accumulate the gradients for, before performing a backward/update pass.
+
+
+
+ When using gradient accumulation, one step is counted as one step with backward pass. Therefore, logging,
+ evaluation, save will be conducted every `gradient_accumulation_steps * xxx_step` training examples.
+
+
+
+ eval_accumulation_steps (`int`, *optional*):
+ Number of predictions steps to accumulate the output tensors for, before moving the results to the CPU. If
+ left unset, the whole predictions are accumulated on the device accelerator before being moved to the CPU (faster but
+ requires more memory).
+ eval_delay (`float`, *optional*):
+ Number of epochs or steps to wait for before the first evaluation can be performed, depending on the
+ eval_strategy.
+ torch_empty_cache_steps (`int`, *optional*):
+ Number of steps to wait before calling `torch..empty_cache()`. If left unset or set to None, cache will not be emptied.
+
+
+
+ This can help avoid CUDA out-of-memory errors by lowering peak VRAM usage at a cost of about [10% slower performance](https://github.com/huggingface/transformers/issues/31372).
+
+
+
+ learning_rate (`float`, *optional*, defaults to 5e-5):
+ The initial learning rate for [`AdamW`] optimizer.
+ weight_decay (`float`, *optional*, defaults to 0):
+ The weight decay to apply (if not zero) to all layers except all bias and LayerNorm weights in [`AdamW`]
+ optimizer.
+ adam_beta1 (`float`, *optional*, defaults to 0.9):
+ The beta1 hyperparameter for the [`AdamW`] optimizer.
+ adam_beta2 (`float`, *optional*, defaults to 0.999):
+ The beta2 hyperparameter for the [`AdamW`] optimizer.
+ adam_epsilon (`float`, *optional*, defaults to 1e-8):
+ The epsilon hyperparameter for the [`AdamW`] optimizer.
+ max_grad_norm (`float`, *optional*, defaults to 1.0):
+ Maximum gradient norm (for gradient clipping).
+ num_train_epochs(`float`, *optional*, defaults to 3.0):
+ Total number of training epochs to perform (if not an integer, will perform the decimal part percents of
+ the last epoch before stopping training).
+ max_steps (`int`, *optional*, defaults to -1):
+ If set to a positive number, the total number of training steps to perform. Overrides `num_train_epochs`.
+ For a finite dataset, training is reiterated through the dataset (if all data is exhausted) until
+ `max_steps` is reached.
+ lr_scheduler_type (`str` or [`SchedulerType`], *optional*, defaults to `"linear"`):
+ The scheduler type to use. See the documentation of [`SchedulerType`] for all possible values.
+ lr_scheduler_kwargs (`dict` or `str`, *optional*, defaults to `None`):
+ The extra arguments for the lr_scheduler. See the documentation of each scheduler for possible values.
+ warmup_ratio (`float`, *optional*, defaults to 0.0):
+ Ratio of total training steps used for a linear warmup from 0 to `learning_rate`.
+ warmup_steps (`int`, *optional*, defaults to 0):
+ Number of steps used for a linear warmup from 0 to `learning_rate`. Overrides any effect of `warmup_ratio`.
+ log_level (`str`, *optional*, defaults to `passive`):
+ Logger log level to use on the main process. Possible choices are the log levels as strings: 'debug',
+ 'info', 'warning', 'error' and 'critical', plus a 'passive' level which doesn't set anything and keeps the
+ current log level for the Transformers library (which will be `"warning"` by default).
+ log_level_replica (`str`, *optional*, defaults to `"warning"`):
+ Logger log level to use on replicas. Same choices as `log_level`"
+ log_on_each_node (`bool`, *optional*, defaults to `True`):
+ In multinode distributed training, whether to log using `log_level` once per node, or only on the main
+ node.
+ logging_dir (`str`, *optional*):
+ [TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to
+ *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.
+ logging_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"steps"`):
+ The logging strategy to adopt during training. Possible values are:
+
+ - `"no"`: No logging is done during training.
+ - `"epoch"`: Logging is done at the end of each epoch.
+ - `"steps"`: Logging is done every `logging_steps`.
+
+ logging_first_step (`bool`, *optional*, defaults to `False`):
+ Whether to log the first `global_step` or not.
+ logging_steps (`int` or `float`, *optional*, defaults to 500):
+ Number of update steps between two logs if `logging_strategy="steps"`. Should be an integer or a float in
+ range `[0,1)`. If smaller than 1, will be interpreted as ratio of total training steps.
+ logging_nan_inf_filter (`bool`, *optional*, defaults to `True`):
+ Whether to filter `nan` and `inf` losses for logging. If set to `True` the loss of every step that is `nan`
+ or `inf` is filtered and the average loss of the current logging window is taken instead.
+
+
+
+ `logging_nan_inf_filter` only influences the logging of loss values, it does not change the behavior the
+ gradient is computed or applied to the model.
+
+
+
+ save_strategy (`str` or [`~trainer_utils.SaveStrategy`], *optional*, defaults to `"steps"`):
+ The checkpoint save strategy to adopt during training. Possible values are:
+
+ - `"no"`: No save is done during training.
+ - `"epoch"`: Save is done at the end of each epoch.
+ - `"steps"`: Save is done every `save_steps`.
+ - `"best"`: Save is done whenever a new `best_metric` is achieved.
+
+ If `"epoch"` or `"steps"` is chosen, saving will also be performed at the
+ very end of training, always.
+ save_steps (`int` or `float`, *optional*, defaults to 500):
+ Number of updates steps before two checkpoint saves if `save_strategy="steps"`. Should be an integer or a
+ float in range `[0,1)`. If smaller than 1, will be interpreted as ratio of total training steps.
+ save_total_limit (`int`, *optional*):
+ If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in
+ `output_dir`. When `load_best_model_at_end` is enabled, the "best" checkpoint according to
+ `metric_for_best_model` will always be retained in addition to the most recent ones. For example, for
+ `save_total_limit=5` and `load_best_model_at_end`, the four last checkpoints will always be retained
+ alongside the best model. When `save_total_limit=1` and `load_best_model_at_end`, it is possible that two
+ checkpoints are saved: the last one and the best one (if they are different).
+ save_safetensors (`bool`, *optional*, defaults to `True`):
+ Use [safetensors](https://huggingface.co/docs/safetensors) saving and loading for state dicts instead of
+ default `torch.load` and `torch.save`.
+ save_on_each_node (`bool`, *optional*, defaults to `False`):
+ When doing multi-node distributed training, whether to save models and checkpoints on each node, or only on
+ the main one.
+
+ This should not be activated when the different nodes use the same storage as the files will be saved with
+ the same names for each node.
+ save_only_model (`bool`, *optional*, defaults to `False`):
+ When checkpointing, whether to only save the model, or also the optimizer, scheduler & rng state.
+ Note that when this is true, you won't be able to resume training from checkpoint.
+ This enables you to save storage by not storing the optimizer, scheduler & rng state.
+ You can only load the model using `from_pretrained` with this option set to `True`.
+ restore_callback_states_from_checkpoint (`bool`, *optional*, defaults to `False`):
+ Whether to restore the callback states from the checkpoint. If `True`, will override
+ callbacks passed to the `Trainer` if they exist in the checkpoint."
+ use_cpu (`bool`, *optional*, defaults to `False`):
+ Whether or not to use cpu. If set to False, we will use cuda or mps device if available.
+ seed (`int`, *optional*, defaults to 42):
+ Random seed that will be set at the beginning of training. To ensure reproducibility across runs, use the
+ [`~Trainer.model_init`] function to instantiate the model if it has some randomly initialized parameters.
+ data_seed (`int`, *optional*):
+ Random seed to be used with data samplers. If not set, random generators for data sampling will use the
+ same seed as `seed`. This can be used to ensure reproducibility of data sampling, independent of the model
+ seed.
+ jit_mode_eval (`bool`, *optional*, defaults to `False`):
+ Whether or not to use PyTorch jit trace for inference.
+ bf16 (`bool`, *optional*, defaults to `False`):
+ Whether to use bf16 16-bit (mixed) precision training instead of 32-bit training. Requires Ampere or higher
+ NVIDIA architecture or Intel XPU or using CPU (use_cpu) or Ascend NPU.
+ fp16 (`bool`, *optional*, defaults to `False`):
+ Whether to use fp16 16-bit (mixed) precision training instead of 32-bit training.
+ fp16_opt_level (`str`, *optional*, defaults to 'O1'):
+ For `fp16` training, Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. See details on
+ the [Apex documentation](https://nvidia.github.io/apex/amp).
+ fp16_backend (`str`, *optional*, defaults to `"auto"`):
+ This argument is deprecated. Use `half_precision_backend` instead.
+ half_precision_backend (`str`, *optional*, defaults to `"auto"`):
+ The backend to use for mixed precision training. Must be one of `"auto", "apex", "cpu_amp"`. `"auto"` will
+ use CPU/CUDA AMP or APEX depending on the PyTorch version detected, while the other choices will force the
+ requested backend.
+ bf16_full_eval (`bool`, *optional*, defaults to `False`):
+ Whether to use full bfloat16 evaluation instead of 32-bit. This will be faster and save memory but can harm
+ metric values.
+ fp16_full_eval (`bool`, *optional*, defaults to `False`):
+ Whether to use full float16 evaluation instead of 32-bit. This will be faster and save memory but can harm
+ metric values.
+ tf32 (`bool`, *optional*):
+ Whether to enable the TF32 mode, available in Ampere and newer GPU architectures. The default value depends
+ on PyTorch's version default of `torch.backends.cuda.matmul.allow_tf32`. For more details please refer to
+ the [TF32](https://huggingface.co/docs/transformers/perf_train_gpu_one#tf32) documentation. This is an
+ experimental API and it may change.
+ local_rank (`int`, *optional*, defaults to -1):
+ Rank of the process during distributed training.
+ ddp_backend (`str`, *optional*):
+ The backend to use for distributed training. Must be one of `"nccl"`, `"mpi"`, `"ccl"`, `"gloo"`, `"hccl"`.
+ tpu_num_cores (`int`, *optional*):
+ When training on TPU, the number of TPU cores (automatically passed by launcher script).
+ dataloader_drop_last (`bool`, *optional*, defaults to `False`):
+ Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch size)
+ or not.
+ eval_steps (`int` or `float`, *optional*):
+ Number of update steps between two evaluations if `eval_strategy="steps"`. Will default to the same
+ value as `logging_steps` if not set. Should be an integer or a float in range `[0,1)`. If smaller than 1,
+ will be interpreted as ratio of total training steps.
+ dataloader_num_workers (`int`, *optional*, defaults to 0):
+ Number of subprocesses to use for data loading (PyTorch only). 0 means that the data will be loaded in the
+ main process.
+ past_index (`int`, *optional*, defaults to -1):
+ Some models like [TransformerXL](../model_doc/transformerxl) or [XLNet](../model_doc/xlnet) can make use of
+ the past hidden states for their predictions. If this argument is set to a positive int, the `Trainer` will
+ use the corresponding output (usually index 2) as the past state and feed it to the model at the next
+ training step under the keyword argument `mems`.
+ run_name (`str`, *optional*, defaults to `output_dir`):
+ A descriptor for the run. Typically used for [trackio](https://github.com/gradio-app/trackio),
+ [wandb](https://www.wandb.com/), [mlflow](https://www.mlflow.org/), [comet](https://www.comet.com/site) and
+ [swanlab](https://swanlab.cn) logging. If not specified, will be the same as `output_dir`.
+ disable_tqdm (`bool`, *optional*):
+ Whether or not to disable the tqdm progress bars and table of metrics produced by
+ [`~notebook.NotebookTrainingTracker`] in Jupyter Notebooks. Will default to `True` if the logging level is
+ set to warn or lower (default), `False` otherwise.
+ remove_unused_columns (`bool`, *optional*, defaults to `True`):
+ Whether or not to automatically remove the columns unused by the model forward method.
+ label_names (`list[str]`, *optional*):
+ The list of keys in your dictionary of inputs that correspond to the labels.
+
+ Will eventually default to the list of argument names accepted by the model that contain the word "label",
+ except if the model used is one of the `XxxForQuestionAnswering` in which case it will also include the
+ `["start_positions", "end_positions"]` keys.
+
+ You should only specify `label_names` if you're using custom label names or if your model's `forward` consumes multiple label tensors (e.g., extractive QA).
+ load_best_model_at_end (`bool`, *optional*, defaults to `False`):
+ Whether or not to load the best model found during training at the end of training. When this option is
+ enabled, the best checkpoint will always be saved. See
+ [`save_total_limit`](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments.save_total_limit)
+ for more.
+
+
+
+ When set to `True`, the parameters `save_strategy` needs to be the same as `eval_strategy`, and in
+ the case it is "steps", `save_steps` must be a round multiple of `eval_steps`.
+
+
+
+ metric_for_best_model (`str`, *optional*):
+ Use in conjunction with `load_best_model_at_end` to specify the metric to use to compare two different
+ models. Must be the name of a metric returned by the evaluation with or without the prefix `"eval_"`.
+
+ If not specified, this will default to `"loss"` when either `load_best_model_at_end == True`
+ or `lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU` (to use the evaluation loss).
+
+ If you set this value, `greater_is_better` will default to `True` unless the name ends with "loss".
+ Don't forget to set it to `False` if your metric is better when lower.
+ greater_is_better (`bool`, *optional*):
+ Use in conjunction with `load_best_model_at_end` and `metric_for_best_model` to specify if better models
+ should have a greater metric or not. Will default to:
+
+ - `True` if `metric_for_best_model` is set to a value that doesn't end in `"loss"`.
+ - `False` if `metric_for_best_model` is not set, or set to a value that ends in `"loss"`.
+ ignore_data_skip (`bool`, *optional*, defaults to `False`):
+ When resuming training, whether or not to skip the epochs and batches to get the data loading at the same
+ stage as in the previous training. If set to `True`, the training will begin faster (as that skipping step
+ can take a long time) but will not yield the same results as the interrupted training would have.
+ fsdp (`bool`, `str` or list of [`~trainer_utils.FSDPOption`], *optional*, defaults to `None`):
+ Use PyTorch Distributed Parallel Training (in distributed training only).
+
+ A list of options along the following:
+
+ - `"full_shard"`: Shard parameters, gradients and optimizer states.
+ - `"shard_grad_op"`: Shard optimizer states and gradients.
+ - `"hybrid_shard"`: Apply `FULL_SHARD` within a node, and replicate parameters across nodes.
+ - `"hybrid_shard_zero2"`: Apply `SHARD_GRAD_OP` within a node, and replicate parameters across nodes.
+ - `"offload"`: Offload parameters and gradients to CPUs (only compatible with `"full_shard"` and
+ `"shard_grad_op"`).
+ - `"auto_wrap"`: Automatically recursively wrap layers with FSDP using `default_auto_wrap_policy`.
+ fsdp_config (`str` or `dict`, *optional*):
+ Config to be used with fsdp (Pytorch Distributed Parallel Training). The value is either a location of
+ fsdp json config file (e.g., `fsdp_config.json`) or an already loaded json file as `dict`.
+
+ A List of config and its options:
+ - min_num_params (`int`, *optional*, defaults to `0`):
+ FSDP's minimum number of parameters for Default Auto Wrapping. (useful only when `fsdp` field is
+ passed).
+ - transformer_layer_cls_to_wrap (`list[str]`, *optional*):
+ List of transformer layer class names (case-sensitive) to wrap, e.g, `BertLayer`, `GPTJBlock`,
+ `T5Block` .... (useful only when `fsdp` flag is passed).
+ - backward_prefetch (`str`, *optional*)
+ FSDP's backward prefetch mode. Controls when to prefetch next set of parameters (useful only when
+ `fsdp` field is passed).
+
+ A list of options along the following:
+
+ - `"backward_pre"` : Prefetches the next set of parameters before the current set of parameter's
+ gradient computation.
+ - `"backward_post"` : This prefetches the next set of parameters after the current set of
+ parameter's gradient computation.
+ - forward_prefetch (`bool`, *optional*, defaults to `False`)
+ FSDP's forward prefetch mode (useful only when `fsdp` field is passed).
+ If `"True"`, then FSDP explicitly prefetches the next upcoming all-gather while executing in the
+ forward pass.
+ - limit_all_gathers (`bool`, *optional*, defaults to `False`)
+ FSDP's limit_all_gathers (useful only when `fsdp` field is passed).
+ If `"True"`, FSDP explicitly synchronizes the CPU thread to prevent too many in-flight
+ all-gathers.
+ - use_orig_params (`bool`, *optional*, defaults to `True`)
+ If `"True"`, allows non-uniform `requires_grad` during init, which means support for interspersed
+ frozen and trainable parameters. Useful in cases such as parameter-efficient fine-tuning. Please
+ refer this
+ [blog](https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019
+ - sync_module_states (`bool`, *optional*, defaults to `True`)
+ If `"True"`, each individually wrapped FSDP unit will broadcast module parameters from rank 0 to
+ ensure they are the same across all ranks after initialization
+ - cpu_ram_efficient_loading (`bool`, *optional*, defaults to `False`)
+ If `"True"`, only the first process loads the pretrained model checkpoint while all other processes
+ have empty weights. When this setting as `"True"`, `sync_module_states` also must to be `"True"`,
+ otherwise all the processes except the main process would have random weights leading to unexpected
+ behaviour during training.
+ - activation_checkpointing (`bool`, *optional*, defaults to `False`):
+ If `"True"`, activation checkpointing is a technique to reduce memory usage by clearing activations of
+ certain layers and recomputing them during a backward pass. Effectively, this trades extra
+ computation time for reduced memory usage.
+ - xla (`bool`, *optional*, defaults to `False`):
+ Whether to use PyTorch/XLA Fully Sharded Data Parallel Training. This is an experimental feature
+ and its API may evolve in the future.
+ - xla_fsdp_settings (`dict`, *optional*)
+ The value is a dictionary which stores the XLA FSDP wrapping parameters.
+
+ For a complete list of options, please see [here](
+ https://github.com/pytorch/xla/blob/master/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py).
+ - xla_fsdp_grad_ckpt (`bool`, *optional*, defaults to `False`):
+ Will use gradient checkpointing over each nested XLA FSDP wrapped layer. This setting can only be
+ used when the xla flag is set to true, and an auto wrapping policy is specified through
+ fsdp_min_num_params or fsdp_transformer_layer_cls_to_wrap.
+ deepspeed (`str` or `dict`, *optional*):
+ Use [Deepspeed](https://github.com/deepspeedai/DeepSpeed). This is an experimental feature and its API may
+ evolve in the future. The value is either the location of DeepSpeed json config file (e.g.,
+ `ds_config.json`) or an already loaded json file as a `dict`"
+
+
+ If enabling any Zero-init, make sure that your model is not initialized until
+ *after* initializing the `TrainingArguments`, else it will not be applied.
+
+
+ accelerator_config (`str`, `dict`, or `AcceleratorConfig`, *optional*):
+ Config to be used with the internal `Accelerator` implementation. The value is either a location of
+ accelerator json config file (e.g., `accelerator_config.json`), an already loaded json file as `dict`,
+ or an instance of [`~trainer_pt_utils.AcceleratorConfig`].
+
+ A list of config and its options:
+ - split_batches (`bool`, *optional*, defaults to `False`):
+ Whether or not the accelerator should split the batches yielded by the dataloaders across the devices. If
+ `True` the actual batch size used will be the same on any kind of distributed processes, but it must be a
+ round multiple of the `num_processes` you are using. If `False`, actual batch size used will be the one set
+ in your script multiplied by the number of processes.
+ - dispatch_batches (`bool`, *optional*):
+ If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process
+ and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose
+ underlying dataset is an `IterableDataset`, `False` otherwise.
+ - even_batches (`bool`, *optional*, defaults to `True`):
+ If set to `True`, in cases where the total batch size across all processes does not exactly divide the
+ dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among
+ all workers.
+ - use_seedable_sampler (`bool`, *optional*, defaults to `True`):
+ Whether or not use a fully seedable random sampler ([`accelerate.data_loader.SeedableRandomSampler`]). Ensures
+ training results are fully reproducible using a different sampling technique. While seed-to-seed results
+ may differ, on average the differences are negligible when using multiple different seeds to compare. Should
+ also be ran with [`~utils.set_seed`] for the best results.
+ - use_configured_state (`bool`, *optional*, defaults to `False`):
+ Whether or not to use a pre-configured `AcceleratorState` or `PartialState` defined before calling `TrainingArguments`.
+ If `True`, an `Accelerator` or `PartialState` must be initialized. Note that by doing so, this could lead to issues
+ with hyperparameter tuning.
+ parallelism_config (`ParallelismConfig`, *optional*):
+ Parallelism configuration for the training run. Requires Accelerate `1.10.1`
+ label_smoothing_factor (`float`, *optional*, defaults to 0.0):
+ The label smoothing factor to use. Zero means no label smoothing, otherwise the underlying onehot-encoded
+ labels are changed from 0s and 1s to `label_smoothing_factor/num_labels` and `1 - label_smoothing_factor +
+ label_smoothing_factor/num_labels` respectively.
+ debug (`str` or list of [`~debug_utils.DebugOption`], *optional*, defaults to `""`):
+ Enable one or more debug features. This is an experimental feature.
+
+ Possible options are:
+
+ - `"underflow_overflow"`: detects overflow in model's input/outputs and reports the last frames that led to
+ the event
+ - `"tpu_metrics_debug"`: print debug metrics on TPU
+
+ The options should be separated by whitespaces.
+ optim (`str` or [`training_args.OptimizerNames`], *optional*, defaults to `"adamw_torch"` (for torch>=2.8 `"adamw_torch_fused"`)):
+ The optimizer to use, such as "adamw_torch", "adamw_torch_fused", "adamw_apex_fused", "adamw_anyprecision",
+ "adafactor". See `OptimizerNames` in [training_args.py](https://github.com/huggingface/transformers/blob/main/src/transformers/training_args.py)
+ for a full list of optimizers.
+ optim_args (`str`, *optional*):
+ Optional arguments that are supplied to optimizers such as AnyPrecisionAdamW, AdEMAMix, and GaLore.
+ group_by_length (`bool`, *optional*, defaults to `False`):
+ Whether or not to group together samples of roughly the same length in the training dataset (to minimize
+ padding applied and be more efficient). Only useful if applying dynamic padding.
+ length_column_name (`str`, *optional*, defaults to `"length"`):
+ Column name for precomputed lengths. If the column exists, grouping by length will use these values rather
+ than computing them on train startup. Ignored unless `group_by_length` is `True` and the dataset is an
+ instance of `Dataset`.
+ report_to (`str` or `list[str]`, *optional*, defaults to `"all"`):
+ The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,
+ `"clearml"`, `"codecarbon"`, `"comet_ml"`, `"dagshub"`, `"dvclive"`, `"flyte"`, `"mlflow"`, `"neptune"`,
+ `"swanlab"`, `"tensorboard"`, `"trackio"` and `"wandb"`. Use `"all"` to report to all integrations
+ installed, `"none"` for no integrations.
+ project (`str`, *optional*, defaults to `"huggingface"`):
+ The name of the project to use for logging. Currently, only used by Trackio.
+ trackio_space_id (`str` or `None`, *optional*, defaults to `"trackio"`):
+ The Hugging Face Space ID to deploy to when using Trackio. Should be a complete Space name like
+ `'username/reponame'` or `'orgname/reponame' `, or just `'reponame'` in which case the Space will be
+ created in the currently-logged-in Hugging Face user's namespace. If `None`, will log to a local directory.
+ Note that this Space will be public unless you set `hub_private_repo=True` or your organization's default
+ is to create private Spaces."
+ ddp_find_unused_parameters (`bool`, *optional*):
+ When using distributed training, the value of the flag `find_unused_parameters` passed to
+ `DistributedDataParallel`. Will default to `False` if gradient checkpointing is used, `True` otherwise.
+ ddp_bucket_cap_mb (`int`, *optional*):
+ When using distributed training, the value of the flag `bucket_cap_mb` passed to `DistributedDataParallel`.
+ ddp_broadcast_buffers (`bool`, *optional*):
+ When using distributed training, the value of the flag `broadcast_buffers` passed to
+ `DistributedDataParallel`. Will default to `False` if gradient checkpointing is used, `True` otherwise.
+ dataloader_pin_memory (`bool`, *optional*, defaults to `True`):
+ Whether you want to pin memory in data loaders or not. Will default to `True`.
+ dataloader_persistent_workers (`bool`, *optional*, defaults to `False`):
+ If True, the data loader will not shut down the worker processes after a dataset has been consumed once.
+ This allows to maintain the workers Dataset instances alive. Can potentially speed up training, but will
+ increase RAM usage. Will default to `False`.
+ dataloader_prefetch_factor (`int`, *optional*):
+ Number of batches loaded in advance by each worker.
+ 2 means there will be a total of 2 * num_workers batches prefetched across all workers.
+ skip_memory_metrics (`bool`, *optional*, defaults to `True`):
+ Whether to skip adding of memory profiler reports to metrics. This is skipped by default because it slows
+ down the training and evaluation speed.
+ push_to_hub (`bool`, *optional*, defaults to `False`):
+ Whether or not to push the model to the Hub every time the model is saved. If this is activated,
+ `output_dir` will begin a git directory synced with the repo (determined by `hub_model_id`) and the content
+ will be pushed each time a save is triggered (depending on your `save_strategy`). Calling
+ [`~Trainer.save_model`] will also trigger a push.
+
+
+
+ If `output_dir` exists, it needs to be a local clone of the repository to which the [`Trainer`] will be
+ pushed.
+
+
+
+ resume_from_checkpoint (`str`, *optional*):
+ The path to a folder with a valid checkpoint for your model. This argument is not directly used by
+ [`Trainer`], it's intended to be used by your training/evaluation scripts instead. See the [example
+ scripts](https://github.com/huggingface/transformers/tree/main/examples) for more details.
+ hub_model_id (`str`, *optional*):
+ The name of the repository to keep in sync with the local *output_dir*. It can be a simple model ID in
+ which case the model will be pushed in your namespace. Otherwise it should be the whole repository name,
+ for instance `"user_name/model"`, which allows you to push to an organization you are a member of with
+ `"organization_name/model"`. Will default to `user_name/output_dir_name` with *output_dir_name* being the
+ name of `output_dir`.
+
+ Will default to the name of `output_dir`.
+ hub_strategy (`str` or [`~trainer_utils.HubStrategy`], *optional*, defaults to `"every_save"`):
+ Defines the scope of what is pushed to the Hub and when. Possible values are:
+
+ - `"end"`: push the model, its configuration, the processing class e.g. tokenizer (if passed along to the [`Trainer`]) and a
+ draft of a model card when the [`~Trainer.save_model`] method is called.
+ - `"every_save"`: push the model, its configuration, the processing class e.g. tokenizer (if passed along to the [`Trainer`]) and
+ a draft of a model card each time there is a model save. The pushes are asynchronous to not block
+ training, and in case the save are very frequent, a new push is only attempted if the previous one is
+ finished. A last push is made with the final model at the end of training.
+ - `"checkpoint"`: like `"every_save"` but the latest checkpoint is also pushed in a subfolder named
+ last-checkpoint, allowing you to resume training easily with
+ `trainer.train(resume_from_checkpoint="last-checkpoint")`.
+ - `"all_checkpoints"`: like `"checkpoint"` but all checkpoints are pushed like they appear in the output
+ folder (so you will get one checkpoint folder per folder in your final repository)
+
+ hub_token (`str`, *optional*):
+ The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with
+ `hf auth login`.
+ hub_private_repo (`bool`, *optional*):
+ Whether to make the repo private. If `None` (default), the repo will be public unless the organization's
+ default is private. This value is ignored if the repo already exists. If reporting to Trackio with
+ deployment to Hugging Face Spaces enabled, the same logic determines whether the Space is private.
+ hub_always_push (`bool`, *optional*, defaults to `False`):
+ Unless this is `True`, the `Trainer` will skip pushing a checkpoint when the previous push is not finished.
+ hub_revision (`str`, *optional*):
+ The revision to use when pushing to the Hub. Can be a branch name, a tag, or a commit hash.
+ gradient_checkpointing (`bool`, *optional*, defaults to `False`):
+ If True, use gradient checkpointing to save memory at the expense of slower backward pass.
+ gradient_checkpointing_kwargs (`dict`, *optional*, defaults to `None`):
+ Key word arguments to be passed to the `gradient_checkpointing_enable` method.
+ include_inputs_for_metrics (`bool`, *optional*, defaults to `False`):
+ This argument is deprecated. Use `include_for_metrics` instead, e.g, `include_for_metrics = ["inputs"]`.
+ include_for_metrics (`list[str]`, *optional*, defaults to `[]`):
+ Include additional data in the `compute_metrics` function if needed for metrics computation.
+ Possible options to add to `include_for_metrics` list:
+ - `"inputs"`: Input data passed to the model, intended for calculating input dependent metrics.
+ - `"loss"`: Loss values computed during evaluation, intended for calculating loss dependent metrics.
+ eval_do_concat_batches (`bool`, *optional*, defaults to `True`):
+ Whether to recursively concat inputs/losses/labels/predictions across batches. If `False`,
+ will instead store them as lists, with each batch kept separate.
+ auto_find_batch_size (`bool`, *optional*, defaults to `False`)
+ Whether to find a batch size that will fit into memory automatically through exponential decay, avoiding
+ CUDA Out-of-Memory errors. Requires accelerate to be installed (`pip install accelerate`)
+ full_determinism (`bool`, *optional*, defaults to `False`)
+ If `True`, [`enable_full_determinism`] is called instead of [`set_seed`] to ensure reproducible results in
+ distributed training. Important: this will negatively impact the performance, so only use it for debugging.
+ torchdynamo (`str`, *optional*):
+ If set, the backend compiler for TorchDynamo. Possible choices are `"eager"`, `"aot_eager"`, `"inductor"`,
+ `"nvfuser"`, `"aot_nvfuser"`, `"aot_cudagraphs"`, `"ofi"`, `"fx2trt"`, `"onnxrt"` and `"ipex"`.
+ ray_scope (`str`, *optional*, defaults to `"last"`):
+ The scope to use when doing hyperparameter search with Ray. By default, `"last"` will be used. Ray will
+ then use the last checkpoint of all trials, compare those, and select the best one. However, other options
+ are also available. See the [Ray documentation](
+ https://docs.ray.io/en/latest/tune/api_docs/analysis.html#ray.tune.ExperimentAnalysis.get_best_trial) for
+ more options.
+ ddp_timeout (`int`, *optional*, defaults to 1800):
+ The timeout for `torch.distributed.init_process_group` calls, used to avoid GPU socket timeouts when
+ performing slow operations in distributed runnings. Please refer the [PyTorch documentation]
+ (https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) for more
+ information.
+ use_mps_device (`bool`, *optional*, defaults to `False`):
+ This argument is deprecated.`mps` device will be used if it is available similar to `cuda` device.
+ torch_compile (`bool`, *optional*, defaults to `False`):
+ Whether or not to compile the model using PyTorch 2.0
+ [`torch.compile`](https://pytorch.org/get-started/pytorch-2.0/).
+
+ This will use the best defaults for the [`torch.compile`
+ API](https://pytorch.org/docs/stable/generated/torch.compile.html?highlight=torch+compile#torch.compile).
+ You can customize the defaults with the argument `torch_compile_backend` and `torch_compile_mode` but we
+ don't guarantee any of them will work as the support is progressively rolled in in PyTorch.
+
+ This flag and the whole compile API is experimental and subject to change in future releases.
+ torch_compile_backend (`str`, *optional*):
+ The backend to use in `torch.compile`. If set to any value, `torch_compile` will be set to `True`.
+
+ Refer to the PyTorch doc for possible values and note that they may change across PyTorch versions.
+
+ This flag is experimental and subject to change in future releases.
+ torch_compile_mode (`str`, *optional*):
+ The mode to use in `torch.compile`. If set to any value, `torch_compile` will be set to `True`.
+
+ Refer to the PyTorch doc for possible values and note that they may change across PyTorch versions.
+
+ This flag is experimental and subject to change in future releases.
+ include_tokens_per_second (`bool`, *optional*, defaults to `False`):
+ Whether or not to compute the number of tokens per second per device for training speed metrics.
+
+ This will iterate over the entire training dataloader once beforehand,
+ and will slow down the entire process.
+
+ include_num_input_tokens_seen (`bool`, *optional*):
+ Whether or not to track the number of input tokens seen throughout training.
+
+ May be slower in distributed training as gather operations must be called.
+
+ neftune_noise_alpha (`Optional[float]`):
+ If not `None`, this will activate NEFTune noise embeddings. This can drastically improve model performance
+ for instruction fine-tuning. Check out the [original paper](https://huggingface.co/papers/2310.05914) and the
+ [original code](https://github.com/neelsjain/NEFTune). Support transformers `PreTrainedModel` and also
+ `PeftModel` from peft. The original paper used values in the range [5.0, 15.0].
+ optim_target_modules (`Union[str, list[str]]`, *optional*):
+ The target modules to optimize, i.e. the module names that you would like to train.
+ Currently used for the GaLore algorithm (https://huggingface.co/papers/2403.03507) and APOLLO algorithm (https://huggingface.co/papers/2412.05270).
+ See GaLore implementation (https://github.com/jiaweizzhao/GaLore) and APOLLO implementation (https://github.com/zhuhanqing/APOLLO) for more details.
+ You need to make sure to pass a valid GaLore or APOLLO optimizer, e.g., one of: "apollo_adamw", "galore_adamw", "galore_adamw_8bit", "galore_adafactor" and make sure that the target modules are `nn.Linear` modules only.
+
+ batch_eval_metrics (`bool`, *optional*, defaults to `False`):
+ If set to `True`, evaluation will call compute_metrics at the end of each batch to accumulate statistics
+ rather than saving all eval logits in memory. When set to `True`, you must pass a compute_metrics function
+ that takes a boolean argument `compute_result`, which when passed `True`, will trigger the final global
+ summary statistics from the batch-level summary statistics you've accumulated over the evaluation set.
+
+ eval_on_start (`bool`, *optional*, defaults to `False`):
+ Whether to perform a evaluation step (sanity check) before the training to ensure the validation steps works correctly.
+
+ eval_use_gather_object (`bool`, *optional*, defaults to `False`):
+ Whether to run recursively gather object in a nested list/tuple/dictionary of objects from all devices. This should only be enabled if users are not just returning tensors, and this is actively discouraged by PyTorch.
+
+ use_liger_kernel (`bool`, *optional*, defaults to `False`):
+ Whether enable [Liger](https://github.com/linkedin/Liger-Kernel) Kernel for LLM model training.
+ It can effectively increase multi-GPU training throughput by ~20% and reduces memory usage by ~60%, works out of the box with
+ flash attention, PyTorch FSDP, and Microsoft DeepSpeed. Currently, it supports llama, mistral, mixtral and gemma models.
+
+ liger_kernel_config (`Optional[dict]`, *optional*):
+ Configuration to be used for Liger Kernel. When use_liger_kernel=True, this dict is passed as keyword arguments to the
+ `_apply_liger_kernel_to_instance` function, which specifies which kernels to apply. Available options vary by model but typically
+ include: 'rope', 'swiglu', 'cross_entropy', 'fused_linear_cross_entropy', 'rms_norm', etc. If `None`, use the default kernel configurations.
+
+ average_tokens_across_devices (`bool`, *optional*, defaults to `True`):
+ Whether or not to average tokens across devices. If enabled, will use all_reduce to synchronize
+ num_tokens_in_batch for precise loss calculation. Reference:
+ https://github.com/huggingface/transformers/issues/34242
+ """
+
+ # Sometimes users will pass in a `str` repr of a dict in the CLI
+ # We need to track what fields those can be. Each time a new arg
+ # has a dict type, it must be added to this list.
+ # Important: These should be typed with Optional[Union[dict,str,...]]
+ _VALID_DICT_FIELDS = [
+ "accelerator_config",
+ "fsdp_config",
+ "deepspeed",
+ "gradient_checkpointing_kwargs",
+ "lr_scheduler_kwargs",
+ ]
+ framework = "pt"
+
+ output_dir: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": "The output directory where the model predictions and checkpoints will be written. Defaults to 'trainer_output' if not provided."
+ },
+ )
+ overwrite_output_dir: bool = field(
+ default=False,
+ metadata={
+ "help": (
+ "Overwrite the content of the output directory. "
+ "Use this to continue training if output_dir points to a checkpoint directory."
+ )
+ },
+ )
+
+ do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
+ do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
+ do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."})
+ eval_strategy: Union[IntervalStrategy, str] = field(
+ default="no",
+ metadata={"help": "The evaluation strategy to use."},
+ )
+ prediction_loss_only: bool = field(
+ default=False,
+ metadata={"help": "When performing evaluation and predictions, only returns the loss."},
+ )
+
+ per_device_train_batch_size: int = field(
+ default=8, metadata={"help": "Batch size per device accelerator core/CPU for training."}
+ )
+ per_device_eval_batch_size: int = field(
+ default=8, metadata={"help": "Batch size per device accelerator core/CPU for evaluation."}
+ )
+
+ per_gpu_train_batch_size: Optional[int] = field(
+ default=None,
+ metadata={
+ "help": (
+ "Deprecated, the use of `--per_device_train_batch_size` is preferred. "
+ "Batch size per GPU/TPU core/CPU for training."
+ )
+ },
+ )
+ per_gpu_eval_batch_size: Optional[int] = field(
+ default=None,
+ metadata={
+ "help": (
+ "Deprecated, the use of `--per_device_eval_batch_size` is preferred. "
+ "Batch size per GPU/TPU core/CPU for evaluation."
+ )
+ },
+ )
+
+ gradient_accumulation_steps: int = field(
+ default=1,
+ metadata={"help": "Number of updates steps to accumulate before performing a backward/update pass."},
+ )
+ eval_accumulation_steps: Optional[int] = field(
+ default=None,
+ metadata={"help": "Number of predictions steps to accumulate before moving the tensors to the CPU."},
+ )
+
+ eval_delay: float = field(
+ default=0,
+ metadata={
+ "help": (
+ "Number of epochs or steps to wait for before the first evaluation can be performed, depending on the"
+ " eval_strategy."
+ )
+ },
+ )
+
+ torch_empty_cache_steps: Optional[int] = field(
+ default=None,
+ metadata={
+ "help": "Number of steps to wait before calling `torch..empty_cache()`."
+ "This can help avoid CUDA out-of-memory errors by lowering peak VRAM usage at a cost of about [10% slower performance](https://github.com/huggingface/transformers/issues/31372)."
+ "If left unset or set to None, cache will not be emptied."
+ },
+ )
+
+ learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
+ weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
+ adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
+ adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
+ adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
+ max_grad_norm: float = field(default=1.0, metadata={"help": "Max gradient norm."})
+
+ num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
+ max_steps: int = field(
+ default=-1,
+ metadata={"help": "If > 0: set total number of training steps to perform. Override num_train_epochs."},
+ )
+ lr_scheduler_type: Union[SchedulerType, str] = field(
+ default="linear",
+ metadata={"help": "The scheduler type to use."},
+ )
+ lr_scheduler_kwargs: Optional[Union[dict, str]] = field(
+ default=None,
+ metadata={
+ "help": (
+ "Extra parameters for the lr_scheduler such as {'num_cycles': 1} for the cosine with hard restarts."
+ )
+ },
+ )
+ warmup_ratio: float = field(
+ default=0.0, metadata={"help": "Linear warmup over warmup_ratio fraction of total steps."}
+ )
+ warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
+
+ log_level: str = field(
+ default="passive",
+ metadata={
+ "help": (
+ "Logger log level to use on the main node. Possible choices are the log levels as strings: 'debug',"
+ " 'info', 'warning', 'error' and 'critical', plus a 'passive' level which doesn't set anything and"
+ " lets the application set the level. Defaults to 'passive'."
+ ),
+ "choices": trainer_log_levels.keys(),
+ },
+ )
+ log_level_replica: str = field(
+ default="warning",
+ metadata={
+ "help": "Logger log level to use on replica nodes. Same choices and defaults as ``log_level``",
+ "choices": trainer_log_levels.keys(),
+ },
+ )
+ log_on_each_node: bool = field(
+ default=True,
+ metadata={
+ "help": (
+ "When doing a multinode distributed training, whether to log once per node or just once on the main"
+ " node."
+ )
+ },
+ )
+ logging_dir: Optional[str] = field(default=None, metadata={"help": "Tensorboard log dir."})
+ logging_strategy: Union[IntervalStrategy, str] = field(
+ default="steps",
+ metadata={"help": "The logging strategy to use."},
+ )
+ logging_first_step: bool = field(default=False, metadata={"help": "Log the first global_step"})
+ logging_steps: float = field(
+ default=500,
+ metadata={
+ "help": (
+ "Log every X updates steps. Should be an integer or a float in range `[0,1)`. "
+ "If smaller than 1, will be interpreted as ratio of total training steps."
+ )
+ },
+ )
+ logging_nan_inf_filter: bool = field(default=True, metadata={"help": "Filter nan and inf losses for logging."})
+ save_strategy: Union[SaveStrategy, str] = field(
+ default="steps",
+ metadata={"help": "The checkpoint save strategy to use."},
+ )
+ save_steps: float = field(
+ default=500,
+ metadata={
+ "help": (
+ "Save checkpoint every X updates steps. Should be an integer or a float in range `[0,1)`. "
+ "If smaller than 1, will be interpreted as ratio of total training steps."
+ )
+ },
+ )
+ save_total_limit: Optional[int] = field(
+ default=None,
+ metadata={
+ "help": (
+ "If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in"
+ " `output_dir`. When `load_best_model_at_end` is enabled, the 'best' checkpoint according to"
+ " `metric_for_best_model` will always be retained in addition to the most recent ones. For example,"
+ " for `save_total_limit=5` and `load_best_model_at_end=True`, the four last checkpoints will always be"
+ " retained alongside the best model. When `save_total_limit=1` and `load_best_model_at_end=True`,"
+ " it is possible that two checkpoints are saved: the last one and the best one (if they are different)."
+ " Default is unlimited checkpoints"
+ )
+ },
+ )
+ save_safetensors: bool = field(
+ default=True,
+ metadata={
+ "help": "Use safetensors saving and loading for state dicts instead of default torch.load and torch.save."
+ },
+ )
+ save_on_each_node: bool = field(
+ default=False,
+ metadata={
+ "help": (
+ "When doing multi-node distributed training, whether to save models and checkpoints on each node, or"
+ " only on the main one"
+ )
+ },
+ )
+ save_only_model: bool = field(
+ default=False,
+ metadata={
+ "help": (
+ "When checkpointing, whether to only save the model, or also the optimizer, scheduler & rng state."
+ "Note that when this is true, you won't be able to resume training from checkpoint."
+ "This enables you to save storage by not storing the optimizer, scheduler & rng state."
+ "You can only load the model using from_pretrained with this option set to True."
+ )
+ },
+ )
+ restore_callback_states_from_checkpoint: bool = field(
+ default=False,
+ metadata={
+ "help": "Whether to restore the callback states from the checkpoint. If `True`, will override callbacks passed to the `Trainer` if they exist in the checkpoint."
+ },
+ )
+ no_cuda: bool = field(
+ default=False,
+ metadata={"help": "This argument is deprecated. It will be removed in version 5.0 of 🤗 Transformers."},
+ )
+ use_cpu: bool = field(
+ default=False,
+ metadata={
+ "help": "Whether or not to use cpu. If left to False, we will use the available torch device/backend (cuda/mps/xpu/hpu etc.)"
+ },
+ )
+ use_mps_device: bool = field(
+ default=False,
+ metadata={
+ "help": "This argument is deprecated. `mps` device will be used if available similar to `cuda` device."
+ " It will be removed in version 5.0 of 🤗 Transformers"
+ },
+ )
+ seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
+ data_seed: Optional[int] = field(default=None, metadata={"help": "Random seed to be used with data samplers."})
+ jit_mode_eval: bool = field(
+ default=False, metadata={"help": "Whether or not to use PyTorch jit trace for inference"}
+ )
+ bf16: bool = field(
+ default=False,
+ metadata={
+ "help": (
+ "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA"
+ " architecture or using CPU (use_cpu) or Ascend NPU. This is an experimental API and it may change."
+ )
+ },
+ )
+ fp16: bool = field(
+ default=False,
+ metadata={"help": "Whether to use fp16 (mixed) precision instead of 32-bit"},
+ )
+ fp16_opt_level: str = field(
+ default="O1",
+ metadata={
+ "help": (
+ "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. "
+ "See details at https://nvidia.github.io/apex/amp.html"
+ )
+ },
+ )
+ half_precision_backend: str = field(
+ default="auto",
+ metadata={
+ "help": "The backend to be used for half precision.",
+ "choices": ["auto", "apex", "cpu_amp"],
+ },
+ )
+ bf16_full_eval: bool = field(
+ default=False,
+ metadata={
+ "help": (
+ "Whether to use full bfloat16 evaluation instead of 32-bit. This is an experimental API and it may"
+ " change."
+ )
+ },
+ )
+ fp16_full_eval: bool = field(
+ default=False,
+ metadata={"help": "Whether to use full float16 evaluation instead of 32-bit"},
+ )
+ tf32: Optional[bool] = field(
+ default=None,
+ metadata={
+ "help": (
+ "Whether to enable tf32 mode, available in Ampere and newer GPU architectures. This is an experimental"
+ " API and it may change."
+ )
+ },
+ )
+ local_rank: int = field(default=-1, metadata={"help": "For distributed training: local_rank"})
+ ddp_backend: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": "The backend to be used for distributed training",
+ "choices": ["nccl", "gloo", "mpi", "ccl", "hccl", "cncl", "mccl"],
+ },
+ )
+ tpu_num_cores: Optional[int] = field(
+ default=None, metadata={"help": "TPU: Number of TPU cores (automatically passed by launcher script)"}
+ )
+ tpu_metrics_debug: bool = field(
+ default=False,
+ metadata={
+ "help": (
+ "Deprecated, the use of `--debug tpu_metrics_debug` is preferred. TPU: Whether to print debug metrics"
+ )
+ },
+ )
+ debug: Union[str, list[DebugOption]] = field(
+ default="",
+ metadata={
+ "help": (
+ "Whether or not to enable debug mode. Current options: "
+ "`underflow_overflow` (Detect underflow and overflow in activations and weights), "
+ "`tpu_metrics_debug` (print debug metrics on TPU)."
+ )
+ },
+ )
+
+ dataloader_drop_last: bool = field(
+ default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."}
+ )
+ eval_steps: Optional[float] = field(
+ default=None,
+ metadata={
+ "help": (
+ "Run an evaluation every X steps. Should be an integer or a float in range `[0,1)`. "
+ "If smaller than 1, will be interpreted as ratio of total training steps."
+ )
+ },
+ )
+ dataloader_num_workers: int = field(
+ default=0,
+ metadata={
+ "help": (
+ "Number of subprocesses to use for data loading (PyTorch only). 0 means that the data will be loaded"
+ " in the main process."
+ )
+ },
+ )
+ dataloader_prefetch_factor: Optional[int] = field(
+ default=None,
+ metadata={
+ "help": (
+ "Number of batches loaded in advance by each worker. "
+ "2 means there will be a total of 2 * num_workers batches prefetched across all workers. "
+ )
+ },
+ )
+ past_index: int = field(
+ default=-1,
+ metadata={"help": "If >=0, uses the corresponding part of the output as the past state for next step."},
+ )
+
+ run_name: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": (
+ "An optional descriptor for the run. Notably used for trackio, wandb, mlflow comet and swanlab "
+ "logging."
+ )
+ },
+ )
+ disable_tqdm: Optional[bool] = field(
+ default=None, metadata={"help": "Whether or not to disable the tqdm progress bars."}
+ )
+
+ remove_unused_columns: bool = field(
+ default=True, metadata={"help": "Remove columns not required by the model when using an nlp.Dataset."}
+ )
+ label_names: Optional[list[str]] = field(
+ default=None, metadata={"help": "The list of keys in your dictionary of inputs that correspond to the labels."}
+ )
+ load_best_model_at_end: bool = field(
+ default=False,
+ metadata={
+ "help": (
+ "Whether or not to load the best model found during training at the end of training. When this option"
+ " is enabled, the best checkpoint will always be saved. See `save_total_limit` for more."
+ )
+ },
+ )
+ metric_for_best_model: Optional[str] = field(
+ default=None, metadata={"help": "The metric to use to compare two different models."}
+ )
+ greater_is_better: Optional[bool] = field(
+ default=None, metadata={"help": "Whether the `metric_for_best_model` should be maximized or not."}
+ )
+ ignore_data_skip: bool = field(
+ default=False,
+ metadata={
+ "help": (
+ "When resuming training, whether or not to skip the first epochs and batches to get to the same"
+ " training data."
+ )
+ },
+ )
+ fsdp: Optional[Union[list[FSDPOption], str]] = field(
+ default=None,
+ metadata={
+ "help": (
+ "Whether or not to use PyTorch Fully Sharded Data Parallel (FSDP) training (in distributed training"
+ " only). The base option should be `full_shard`, `shard_grad_op` or `no_shard` and you can add"
+ " CPU-offload to `full_shard` or `shard_grad_op` like this: full_shard offload` or `shard_grad_op"
+ " offload`. You can add auto-wrap to `full_shard` or `shard_grad_op` with the same syntax: full_shard"
+ " auto_wrap` or `shard_grad_op auto_wrap`."
+ ),
+ },
+ )
+ fsdp_min_num_params: int = field(
+ default=0,
+ metadata={
+ "help": (
+ "This parameter is deprecated. FSDP's minimum number of parameters for Default Auto Wrapping. (useful"
+ " only when `fsdp` field is passed)."
+ )
+ },
+ )
+ fsdp_config: Optional[Union[dict[str, Any], str]] = field(
+ default=None,
+ metadata={
+ "help": (
+ "Config to be used with FSDP (Pytorch Fully Sharded Data Parallel). The value is either a "
+ "fsdp json config file (e.g., `fsdp_config.json`) or an already loaded json file as `dict`."
+ )
+ },
+ )
+ fsdp_transformer_layer_cls_to_wrap: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": (
+ "This parameter is deprecated. Transformer layer class name (case-sensitive) to wrap, e.g,"
+ " `BertLayer`, `GPTJBlock`, `T5Block` .... (useful only when `fsdp` flag is passed)."
+ )
+ },
+ )
+ accelerator_config: Optional[Union[dict, str]] = field(
+ default=None,
+ metadata={
+ "help": (
+ "Config to be used with the internal Accelerator object initialization. The value is either a "
+ "accelerator json config file (e.g., `accelerator_config.json`) or an already loaded json file as `dict`."
+ )
+ },
+ )
+ parallelism_config: Optional[ParallelismConfig] = field(
+ default=None,
+ metadata={"help": ("Parallelism configuration for the training run. Requires Accelerate `1.10.1`")},
+ )
+ deepspeed: Optional[Union[dict, str]] = field(
+ default=None,
+ metadata={
+ "help": (
+ "Enable deepspeed and pass the path to deepspeed json config file (e.g. `ds_config.json`) or an already"
+ " loaded json file as a dict"
+ )
+ },
+ )
+ label_smoothing_factor: float = field(
+ default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."}
+ )
+
+ default_optim = "adamw_torch"
+ if is_torch_available():
+ from .pytorch_utils import is_torch_greater_or_equal_than_2_8
+
+ if is_torch_greater_or_equal_than_2_8:
+ default_optim = "adamw_torch_fused"
+ optim: Union[OptimizerNames, str] = field(
+ default=default_optim,
+ metadata={"help": "The optimizer to use."},
+ )
+ optim_args: Optional[str] = field(default=None, metadata={"help": "Optional arguments to supply to optimizer."})
+ adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
+ group_by_length: bool = field(
+ default=False,
+ metadata={"help": "Whether or not to group samples of roughly the same length together when batching."},
+ )
+ length_column_name: str = field(
+ default="length",
+ metadata={"help": "Column name with precomputed lengths to use when grouping by length."},
+ )
+ report_to: Union[None, str, list[str]] = field(
+ default=None, metadata={"help": "The list of integrations to report the results and logs to."}
+ )
+ project: str = field(
+ default="huggingface",
+ metadata={"help": "The name of the project to use for logging. Currenly, only used by Trackio."},
+ )
+ trackio_space_id: Optional[str] = field(
+ default="trackio",
+ metadata={
+ "help": "The Hugging Face Space ID to deploy to when using Trackio. Should be a complete Space name like "
+ "'username/reponame' or 'orgname/reponame', or just 'reponame' in which case the Space will be created in "
+ "the currently-logged-in Hugging Face user's namespace. If `None`, will log to a local directory. Note "
+ "that this Space will be public unless you set `hub_private_repo=True` or your organization's "
+ "default is to create private Spaces."
+ },
+ )
+ ddp_find_unused_parameters: Optional[bool] = field(
+ default=None,
+ metadata={
+ "help": (
+ "When using distributed training, the value of the flag `find_unused_parameters` passed to "
+ "`DistributedDataParallel`."
+ )
+ },
+ )
+ ddp_bucket_cap_mb: Optional[int] = field(
+ default=None,
+ metadata={
+ "help": (
+ "When using distributed training, the value of the flag `bucket_cap_mb` passed to "
+ "`DistributedDataParallel`."
+ )
+ },
+ )
+ ddp_broadcast_buffers: Optional[bool] = field(
+ default=None,
+ metadata={
+ "help": (
+ "When using distributed training, the value of the flag `broadcast_buffers` passed to "
+ "`DistributedDataParallel`."
+ )
+ },
+ )
+ dataloader_pin_memory: bool = field(
+ default=True, metadata={"help": "Whether or not to pin memory for DataLoader."}
+ )
+ dataloader_persistent_workers: bool = field(
+ default=False,
+ metadata={
+ "help": "If True, the data loader will not shut down the worker processes after a dataset has been consumed once. This allows to maintain the workers Dataset instances alive. Can potentially speed up training, but will increase RAM usage."
+ },
+ )
+ skip_memory_metrics: bool = field(
+ default=True, metadata={"help": "Whether or not to skip adding of memory profiler reports to metrics."}
+ )
+ use_legacy_prediction_loop: bool = field(
+ default=False, metadata={"help": "Whether or not to use the legacy prediction_loop in the Trainer."}
+ )
+ push_to_hub: bool = field(
+ default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."}
+ )
+ resume_from_checkpoint: Optional[str] = field(
+ default=None,
+ metadata={"help": "The path to a folder with a valid checkpoint for your model."},
+ )
+ hub_model_id: Optional[str] = field(
+ default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
+ )
+ hub_strategy: Union[HubStrategy, str] = field(
+ default="every_save",
+ metadata={"help": "The hub strategy to use when `--push_to_hub` is activated."},
+ )
+ hub_token: Optional[str] = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
+ hub_private_repo: Optional[bool] = field(
+ default=None,
+ metadata={
+ "help": "Whether to make the repo private. If `None` (default), the repo will be public unless the "
+ "organization's default is private. This value is ignored if the repo already exists. If reporting to "
+ "Trackio with deployment to Hugging Face Spaces enabled, the same logic determines whether the Space is "
+ "private."
+ },
+ )
+ hub_always_push: bool = field(
+ default=False,
+ metadata={"help": "Unless `True`, the Trainer will skip pushes if the previous one wasn't finished yet."},
+ )
+ hub_revision: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": "The revision to use when pushing to the Hub. Can be a branch name, a tag, or a commit hash."
+ },
+ )
+ gradient_checkpointing: bool = field(
+ default=False,
+ metadata={
+ "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
+ },
+ )
+ gradient_checkpointing_kwargs: Optional[Union[dict[str, Any], str]] = field(
+ default=None,
+ metadata={
+ "help": "Gradient checkpointing key word arguments such as `use_reentrant`. Will be passed to `torch.utils.checkpoint.checkpoint` through `model.gradient_checkpointing_enable`."
+ },
+ )
+ include_inputs_for_metrics: bool = field(
+ default=False,
+ metadata={
+ "help": "This argument is deprecated and will be removed in version 5 of 🤗 Transformers. Use `include_for_metrics` instead."
+ },
+ )
+ include_for_metrics: list[str] = field(
+ default_factory=list,
+ metadata={
+ "help": "List of strings to specify additional data to include in the `compute_metrics` function."
+ "Options: 'inputs', 'loss'."
+ },
+ )
+ eval_do_concat_batches: bool = field(
+ default=True,
+ metadata={
+ "help": "Whether to recursively concat inputs/losses/labels/predictions across batches. If `False`, will instead store them as lists, with each batch kept separate."
+ },
+ )
+ # Deprecated arguments
+ fp16_backend: str = field(
+ default="auto",
+ metadata={
+ "help": "Deprecated. Use half_precision_backend instead",
+ "choices": ["auto", "apex", "cpu_amp"],
+ },
+ )
+ push_to_hub_model_id: Optional[str] = field(
+ default=None, metadata={"help": "The name of the repository to which push the `Trainer`."}
+ )
+ push_to_hub_organization: Optional[str] = field(
+ default=None, metadata={"help": "The name of the organization in with to which push the `Trainer`."}
+ )
+ push_to_hub_token: Optional[str] = field(
+ default=None, metadata={"help": "The token to use to push to the Model Hub."}
+ )
+ _n_gpu: int = field(init=False, repr=False, default=-1)
+ mp_parameters: str = field(
+ default="",
+ metadata={"help": "Used by the SageMaker launcher to send mp-specific args. Ignored in Trainer"},
+ )
+
+ auto_find_batch_size: bool = field(
+ default=False,
+ metadata={
+ "help": (
+ "Whether to automatically decrease the batch size in half and rerun the training loop again each time"
+ " a CUDA Out-of-Memory was reached"
+ )
+ },
+ )
+ full_determinism: bool = field(
+ default=False,
+ metadata={
+ "help": (
+ "Whether to call enable_full_determinism instead of set_seed for reproducibility in distributed"
+ " training. Important: this will negatively impact the performance, so only use it for debugging."
+ )
+ },
+ )
+ torchdynamo: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": "This argument is deprecated, use `--torch_compile_backend` instead.",
+ },
+ )
+ ray_scope: Optional[str] = field(
+ default="last",
+ metadata={
+ "help": (
+ 'The scope to use when doing hyperparameter search with Ray. By default, `"last"` will be used. Ray'
+ " will then use the last checkpoint of all trials, compare those, and select the best one. However,"
+ " other options are also available. See the Ray documentation"
+ " (https://docs.ray.io/en/latest/tune/api_docs/analysis.html"
+ "#ray.tune.ExperimentAnalysis.get_best_trial)"
+ " for more options."
+ )
+ },
+ )
+ ddp_timeout: int = field(
+ default=1800,
+ metadata={
+ "help": "Overrides the default timeout for distributed training (value should be given in seconds)."
+ },
+ )
+ torch_compile: bool = field(
+ default=False, metadata={"help": "If set to `True`, the model will be wrapped in `torch.compile`."}
+ )
+ torch_compile_backend: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": "Which backend to use with `torch.compile`, passing one will trigger a model compilation.",
+ },
+ )
+ torch_compile_mode: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": "Which mode to use with `torch.compile`, passing one will trigger a model compilation.",
+ },
+ )
+
+ include_tokens_per_second: bool = field(
+ default=False,
+ metadata={"help": "If set to `True`, the speed metrics will include `tgs` (tokens per second per device)."},
+ )
+
+ include_num_input_tokens_seen: Union[str, bool] = field(
+ default=False,
+ metadata={
+ "help": (
+ "Whether to track the number of input tokens seen. "
+ "Can be `'all'` to count all tokens, `'non_padding'` to count only non-padding tokens, "
+ "or a boolean (`True` maps to `'all'`, `False` to `'no'`)."
+ )
+ },
+ )
+
+ neftune_noise_alpha: Optional[float] = field(
+ default=None,
+ metadata={
+ "help": "Activates neftune noise embeddings into the model. NEFTune has been proven to drastically improve model performances for instruction fine-tuning. Check out the original paper here: https://huggingface.co/papers/2310.05914 and the original code here: https://github.com/neelsjain/NEFTune. Only supported for `PreTrainedModel` and `PeftModel` classes."
+ },
+ )
+
+ optim_target_modules: Union[None, str, list[str]] = field(
+ default=None,
+ metadata={
+ "help": "Target modules for the optimizer defined in the `optim` argument. Only used for the GaLore optimizer at the moment."
+ },
+ )
+
+ batch_eval_metrics: bool = field(
+ default=False,
+ metadata={"help": "Break eval metrics calculation into batches to save memory."},
+ )
+
+ eval_on_start: bool = field(
+ default=False,
+ metadata={
+ "help": "Whether to run through the entire `evaluation` step at the very beginning of training as a sanity check."
+ },
+ )
+
+ use_liger_kernel: bool = field(
+ default=False,
+ metadata={"help": "Whether or not to enable the Liger Kernel for model training."},
+ )
+
+ liger_kernel_config: Optional[dict[str, bool]] = field(
+ default=None,
+ metadata={
+ "help": (
+ "Configuration to be used for Liger Kernel. When use_liger_kernel=True, "
+ "this dict is passed as keyword arguments to the `_apply_liger_kernel_to_instance` function, "
+ "which specifies which kernels to apply. Available options vary by model "
+ "but typically include: 'rope', 'swiglu', 'cross_entropy', 'fused_linear_cross_entropy', "
+ "'rms_norm', etc. If None, use the default kernel configurations."
+ )
+ },
+ )
+
+ eval_use_gather_object: bool = field(
+ default=False,
+ metadata={
+ "help": "Whether to run recursively gather object in a nested list/tuple/dictionary of objects from all devices."
+ },
+ )
+
+ average_tokens_across_devices: bool = field(
+ default=True,
+ metadata={
+ "help": "Whether or not to average tokens across devices. If enabled, will use all_reduce to "
+ "synchronize num_tokens_in_batch for precise loss calculation. Reference: "
+ "https://github.com/huggingface/transformers/issues/34242"
+ },
+ )
+
+ def __post_init__(self):
+ # Set default output_dir if not provided
+ if self.output_dir is None:
+ self.output_dir = "trainer_output"
+ logger.info(
+ "No output directory specified, defaulting to 'trainer_output'. "
+ "To change this behavior, specify --output_dir when creating TrainingArguments."
+ )
+
+ # Parse in args that could be `dict` sent in from the CLI as a string
+ for field in self._VALID_DICT_FIELDS:
+ passed_value = getattr(self, field)
+ # We only want to do this if the str starts with a bracket to indicate a `dict`
+ # else its likely a filename if supported
+ if isinstance(passed_value, str) and passed_value.startswith("{"):
+ loaded_dict = json.loads(passed_value)
+ # Convert str values to types if applicable
+ loaded_dict = _convert_str_dict(loaded_dict)
+ setattr(self, field, loaded_dict)
+
+ # expand paths, if not os.makedirs("~/bar") will make directory
+ # in the current directory instead of the actual home
+ # see https://github.com/huggingface/transformers/issues/10628
+ if self.output_dir is not None:
+ self.output_dir = os.path.expanduser(self.output_dir)
+ if self.logging_dir is None and self.output_dir is not None:
+ self.logging_dir = os.path.join(self.output_dir, default_logdir())
+ if self.logging_dir is not None:
+ self.logging_dir = os.path.expanduser(self.logging_dir)
+
+ if self.disable_tqdm is None:
+ self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN
+
+ if isinstance(self.eval_strategy, EvaluationStrategy):
+ warnings.warn(
+ "using `EvaluationStrategy` for `eval_strategy` is deprecated and will be removed in version 5"
+ " of 🤗 Transformers. Use `IntervalStrategy` instead",
+ FutureWarning,
+ )
+ # Go back to the underlying string or we won't be able to instantiate `IntervalStrategy` on it.
+ self.eval_strategy = self.eval_strategy.value
+ if self.no_cuda:
+ warnings.warn(
+ "using `no_cuda` is deprecated and will be removed in version 5.0 of 🤗 Transformers. "
+ "Use `use_cpu` instead",
+ FutureWarning,
+ )
+ self.use_cpu = self.no_cuda
+
+ self.eval_strategy = IntervalStrategy(self.eval_strategy)
+ self.logging_strategy = IntervalStrategy(self.logging_strategy)
+ self.save_strategy = SaveStrategy(self.save_strategy)
+ self.hub_strategy = HubStrategy(self.hub_strategy)
+
+ self.lr_scheduler_type = SchedulerType(self.lr_scheduler_type)
+ if self.do_eval is False and self.eval_strategy != IntervalStrategy.NO:
+ self.do_eval = True
+
+ if self.torch_empty_cache_steps is not None:
+ if not (isinstance(self.torch_empty_cache_steps, int) and self.torch_empty_cache_steps > 0):
+ raise ValueError(
+ f"`torch_empty_cache_steps` must be an integer bigger than 0, got {self.torch_empty_cache_steps}."
+ )
+
+ # eval_steps has to be defined and non-zero, fallbacks to logging_steps if the latter is non-zero
+ if self.eval_strategy == IntervalStrategy.STEPS and (self.eval_steps is None or self.eval_steps == 0):
+ if self.logging_steps > 0:
+ logger.info(f"using `logging_steps` to initialize `eval_steps` to {self.logging_steps}")
+ self.eval_steps = self.logging_steps
+ else:
+ raise ValueError(
+ f"evaluation strategy {self.eval_strategy} requires either non-zero --eval_steps or"
+ " --logging_steps"
+ )
+
+ # logging_steps must be non-zero for logging_strategy that is other than 'no'
+ if self.logging_strategy == IntervalStrategy.STEPS and self.logging_steps == 0:
+ raise ValueError(f"logging strategy {self.logging_strategy} requires non-zero --logging_steps")
+
+ if self.logging_strategy == IntervalStrategy.STEPS and self.logging_steps > 1:
+ if self.logging_steps != int(self.logging_steps):
+ raise ValueError(f"--logging_steps must be an integer if bigger than 1: {self.logging_steps}")
+ self.logging_steps = int(self.logging_steps)
+ if self.eval_strategy == IntervalStrategy.STEPS and self.eval_steps > 1:
+ if self.eval_steps != int(self.eval_steps):
+ raise ValueError(f"--eval_steps must be an integer if bigger than 1: {self.eval_steps}")
+ self.eval_steps = int(self.eval_steps)
+ if self.save_strategy == SaveStrategy.STEPS and self.save_steps > 1:
+ if self.save_steps != int(self.save_steps):
+ raise ValueError(f"--save_steps must be an integer if bigger than 1: {self.save_steps}")
+ self.save_steps = int(self.save_steps)
+
+ # Sanity checks for load_best_model_at_end: we require save and eval strategies to be compatible.
+ if self.load_best_model_at_end and self.save_strategy != SaveStrategy.BEST:
+ if self.eval_strategy != self.save_strategy:
+ raise ValueError(
+ "--load_best_model_at_end requires the save and eval strategy to match, but found\n- Evaluation "
+ f"strategy: {self.eval_strategy}\n- Save strategy: {self.save_strategy}"
+ )
+ if self.eval_strategy == IntervalStrategy.STEPS and self.save_steps % self.eval_steps != 0:
+ if self.eval_steps < 1 or self.save_steps < 1:
+ if not (self.eval_steps < 1 and self.save_steps < 1):
+ raise ValueError(
+ "--load_best_model_at_end requires the saving steps to be a multiple of the evaluation "
+ "steps, which cannot get guaranteed when mixing ratio and absolute steps for save_steps "
+ f"{self.save_steps} and eval_steps {self.eval_steps}."
+ )
+ # Work around floating point precision issues
+ LARGE_MULTIPLIER = 1_000_000
+ if (self.save_steps * LARGE_MULTIPLIER) % (self.eval_steps * LARGE_MULTIPLIER) != 0:
+ raise ValueError(
+ "--load_best_model_at_end requires the saving steps to be a multiple of the evaluation "
+ f"steps, but found {self.save_steps}, which is not a multiple of {self.eval_steps}."
+ )
+ else:
+ raise ValueError(
+ "--load_best_model_at_end requires the saving steps to be a round multiple of the evaluation "
+ f"steps, but found {self.save_steps}, which is not a round multiple of {self.eval_steps}."
+ )
+
+ if not self.save_safetensors:
+ logger.info(
+ f"Found safetensors installation, but --save_safetensors={self.save_safetensors}. "
+ f"Safetensors should be a preferred weights saving format due to security and performance reasons. "
+ f"If your model cannot be saved by safetensors please feel free to open an issue at "
+ f"https://github.com/huggingface/safetensors!"
+ )
+
+ if (
+ self.load_best_model_at_end or self.lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU
+ ) and self.metric_for_best_model is None:
+ self.metric_for_best_model = "loss"
+ if self.greater_is_better is None and self.metric_for_best_model is not None:
+ self.greater_is_better = not self.metric_for_best_model.endswith("loss")
+ if self.framework == "pt" and is_torch_available():
+ if self.fp16_backend and self.fp16_backend != "auto":
+ warnings.warn(
+ "`fp16_backend` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"
+ " `half_precision_backend` instead",
+ FutureWarning,
+ )
+ self.half_precision_backend = self.fp16_backend
+
+ if self.bf16 or self.bf16_full_eval:
+ if self.use_cpu and not is_torch_available() and not is_torch_xla_available():
+ # cpu
+ raise ValueError("Your setup doesn't support bf16/(cpu, tpu, neuroncore). You need torch>=1.10")
+ elif not self.use_cpu:
+ if not is_torch_bf16_gpu_available() and not is_torch_xla_available(): # added for tpu support
+ error_message = "Your setup doesn't support bf16/gpu."
+ if is_torch_cuda_available():
+ error_message += " You need Ampere+ GPU with cuda>=11.0"
+ # gpu
+ raise ValueError(error_message)
+
+ if self.fp16 and self.bf16:
+ raise ValueError("At most one of fp16 and bf16 can be True, but not both")
+
+ if self.fp16_full_eval and self.bf16_full_eval:
+ raise ValueError("At most one of fp16 and bf16 can be True for full eval, but not both")
+
+ if self.bf16:
+ if self.half_precision_backend == "apex":
+ raise ValueError(" `--half_precision_backend apex`: GPU bf16 is not supported by apex.")
+
+ if self.half_precision_backend == "apex":
+ if not is_apex_available():
+ raise ImportError(
+ "Using FP16 with APEX but APEX is not installed, please refer to"
+ " https://www.github.com/nvidia/apex."
+ )
+ try:
+ from apex import amp # noqa: F401
+ except ImportError as e:
+ raise ImportError(
+ f"apex.amp is deprecated in the latest version of apex, causing this error {e}. Either revert to an older version or use pytorch amp by setting half_precision_backend='auto' instead. See https://github.com/NVIDIA/apex/pull/1896 "
+ )
+
+ if self.lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU:
+ if self.eval_strategy == IntervalStrategy.NO:
+ raise ValueError("lr_scheduler_type reduce_lr_on_plateau requires an eval strategy")
+ if not is_torch_available():
+ raise ValueError("lr_scheduler_type reduce_lr_on_plateau requires torch>=0.2.0")
+
+ self.optim = OptimizerNames(self.optim)
+ if self.adafactor:
+ warnings.warn(
+ "`--adafactor` is deprecated and will be removed in version 5 of 🤗 Transformers. Use `--optim"
+ " adafactor` instead",
+ FutureWarning,
+ )
+ self.optim = OptimizerNames.ADAFACTOR
+
+ # We need to setup the accelerator config here *before* the first call to `self.device`
+ if is_accelerate_available():
+ if not isinstance(self.accelerator_config, AcceleratorConfig):
+ if self.accelerator_config is None:
+ self.accelerator_config = AcceleratorConfig()
+ elif isinstance(self.accelerator_config, dict):
+ self.accelerator_config = AcceleratorConfig(**self.accelerator_config)
+ # Check that a user didn't pass in the class instantiator
+ # such as `accelerator_config = AcceleratorConfig`
+ elif isinstance(self.accelerator_config, type):
+ raise NotImplementedError(
+ "Tried passing in a callable to `accelerator_config`, but this is not supported. "
+ "Please pass in a fully constructed `AcceleratorConfig` object instead."
+ )
+ else:
+ self.accelerator_config = AcceleratorConfig.from_json_file(self.accelerator_config)
+ if self.accelerator_config.split_batches:
+ logger.info(
+ "Using `split_batches=True` in `accelerator_config` will override the `per_device_train_batch_size` "
+ "Batches will be split across all processes equally when using `split_batches=True`."
+ )
+
+ # Initialize device before we proceed
+ if self.framework == "pt" and is_torch_available():
+ self.device
+
+ if self.torchdynamo is not None:
+ warnings.warn(
+ "`torchdynamo` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"
+ " `torch_compile_backend` instead",
+ FutureWarning,
+ )
+ self.torch_compile_backend = self.torchdynamo
+ if (self.torch_compile_mode is not None or self.torch_compile_backend is not None) and not self.torch_compile:
+ self.torch_compile = True
+ if self.torch_compile and self.torch_compile_backend is None:
+ if not self.use_cpu and is_torch_hpu_available():
+ self.torch_compile_backend = "hpu_backend"
+ else:
+ self.torch_compile_backend = "inductor"
+
+ # accelerate integration for torch compile
+ if self.torch_compile:
+ # set env vars for accelerate
+ prefix = "ACCELERATE_DYNAMO_"
+ os.environ[prefix + "BACKEND"] = self.torch_compile_backend
+ if self.torch_compile_mode is not None:
+ os.environ[prefix + "MODE"] = self.torch_compile_mode
+
+ if self.framework == "pt" and is_torch_available() and self.torch_compile:
+ if is_torch_tf32_available():
+ if self.tf32 is None and not self.fp16 or self.bf16:
+ device_str = "MUSA" if is_torch_musa_available() else "CUDA"
+ logger.info(
+ f"Setting TF32 in {device_str} backends to speedup torch compile, you won't see any improvement"
+ " otherwise."
+ )
+ if is_torch_musa_available():
+ torch.backends.mudnn.allow_tf32 = True
+ else:
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.allow_tf32 = True
+ else:
+ logger.warning(
+ "The speedups for torchdynamo mostly come with GPU Ampere or higher and which is not detected here."
+ )
+ if self.framework == "pt" and is_torch_available() and self.tf32 is not None:
+ if self.tf32:
+ if is_torch_tf32_available():
+ if is_torch_musa_available():
+ torch.backends.mudnn.allow_tf32 = True
+ else:
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.allow_tf32 = True
+ else:
+ raise ValueError("--tf32 requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7")
+ else:
+ if is_torch_tf32_available():
+ if is_torch_musa_available():
+ torch.backends.mudnn.allow_tf32 = False
+ else:
+ torch.backends.cuda.matmul.allow_tf32 = False
+ torch.backends.cudnn.allow_tf32 = False
+ # no need to assert on else
+
+ # if training args is specified, it will override the one specified in the accelerate config
+ mixed_precision_dtype = os.environ.get("ACCELERATE_MIXED_PRECISION", "no")
+ if self.fp16:
+ mixed_precision_dtype = "fp16"
+ elif self.bf16:
+ mixed_precision_dtype = "bf16"
+ os.environ["ACCELERATE_MIXED_PRECISION"] = mixed_precision_dtype
+
+ if self.report_to is None:
+ logger.info(
+ "The default value for the training argument `--report_to` will change in v5 (from all installed "
+ "integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as "
+ "now. You should start updating your code and make this info disappear :-)."
+ )
+ self.report_to = "all"
+ if self.report_to == "all" or self.report_to == ["all"]:
+ # Import at runtime to avoid a circular import.
+ from .integrations import get_available_reporting_integrations
+
+ self.report_to = get_available_reporting_integrations()
+
+ if "codecarbon" in self.report_to and torch.version.hip:
+ logger.warning(
+ "When using the Trainer, CodeCarbonCallback requires the `codecarbon` package, which is not compatible with AMD ROCm (https://github.com/mlco2/codecarbon/pull/490). Automatically disabling the codecarbon callback. Reference: https://huggingface.co/docs/transformers/v4.39.3/en/main_classes/trainer#transformers.TrainingArguments.report_to."
+ )
+ self.report_to.remove("codecarbon")
+
+ elif self.report_to == "none" or self.report_to == ["none"]:
+ self.report_to = []
+ elif not isinstance(self.report_to, list):
+ self.report_to = [self.report_to]
+
+ if self.warmup_ratio < 0 or self.warmup_ratio > 1:
+ raise ValueError("warmup_ratio must lie in range [0,1]")
+ elif self.warmup_ratio > 0 and self.warmup_steps > 0:
+ logger.info(
+ "Both warmup_ratio and warmup_steps given, warmup_steps will override any effect of warmup_ratio"
+ " during training"
+ )
+
+ if not isinstance(self.warmup_steps, int) or self.warmup_steps < 0:
+ raise ValueError("warmup_steps must be of type int and must be 0 or a positive integer.")
+
+ if self.fsdp is None:
+ self.fsdp = []
+ elif self.fsdp is True:
+ self.fsdp = [FSDPOption.FULL_SHARD]
+ elif isinstance(self.fsdp, str):
+ self.fsdp = [FSDPOption(s) for s in self.fsdp.split()]
+
+ if self.fsdp == [FSDPOption.OFFLOAD]:
+ raise ValueError(
+ "`--fsdp offload` can't work on its own. It needs to be added to `--fsdp full_shard` or "
+ '`--fsdp shard_grad_op`. For example, `--fsdp "full_shard offload"`.'
+ )
+ elif FSDPOption.FULL_SHARD in self.fsdp and FSDPOption.SHARD_GRAD_OP in self.fsdp:
+ raise ValueError("`--fsdp full_shard` is not compatible with `--fsdp shard_grad_op`.")
+
+ if self.gradient_checkpointing and (
+ FSDPOption.FULL_SHARD in self.fsdp or FSDPOption.HYBRID_SHARD in self.fsdp
+ ):
+ logger.warning(
+ "When using FSDP full shard, instead of using `gradient_checkpointing` in TrainingArguments, please"
+ " use `activation_checkpointing` in `fsdp_config`. The former introduces a redundant AllGather"
+ " operation in backward pass. Reference: https://github.com/huggingface/transformers/issues/30404"
+ )
+
+ if self.fsdp_config is None:
+ self.fsdp_config = {}
+
+ if isinstance(self.fsdp_config, str):
+ if len(self.fsdp) == 0:
+ warnings.warn("`--fsdp_config` is useful only when `--fsdp` is specified.")
+ with open(self.fsdp_config, encoding="utf-8") as f:
+ self.fsdp_config = json.load(f)
+
+ if self.fsdp_config is not None and isinstance(self.fsdp_config, dict):
+ for k in list(self.fsdp_config.keys()):
+ if k.startswith("fsdp_"):
+ v = self.fsdp_config.pop(k)
+ self.fsdp_config[k[5:]] = v
+
+ if self.fsdp_min_num_params > 0:
+ warnings.warn("using `--fsdp_min_num_params` is deprecated. Use fsdp_config instead ", FutureWarning)
+
+ self.fsdp_config["min_num_params"] = max(self.fsdp_config.get("min_num_params", 0), self.fsdp_min_num_params)
+
+ # if fsdp_config["transformer_layer_cls_to_wrap"] is specified as a string, convert it to a list with a single object
+ if isinstance(self.fsdp_config.get("transformer_layer_cls_to_wrap", None), str):
+ self.fsdp_config["transformer_layer_cls_to_wrap"] = [self.fsdp_config["transformer_layer_cls_to_wrap"]]
+
+ if self.fsdp_transformer_layer_cls_to_wrap is not None:
+ warnings.warn(
+ "using `--fsdp_transformer_layer_cls_to_wrap` is deprecated. Use fsdp_config instead ", FutureWarning
+ )
+ self.fsdp_config["transformer_layer_cls_to_wrap"] = self.fsdp_config.get(
+ "transformer_layer_cls_to_wrap", []
+ ) + [self.fsdp_transformer_layer_cls_to_wrap]
+
+ if len(self.fsdp) == 0 and self.fsdp_config["min_num_params"] > 0:
+ warnings.warn("`min_num_params` is useful only when `--fsdp` is specified.")
+
+ if len(self.fsdp) == 0 and self.fsdp_config.get("transformer_layer_cls_to_wrap", None) is not None:
+ warnings.warn("`transformer_layer_cls_to_wrap` is useful only when `--fsdp` is specified.")
+
+ if (
+ len(self.fsdp) > 0
+ and self.fsdp_config["min_num_params"] > 0
+ and self.fsdp_config.get("transformer_layer_cls_to_wrap", None) is not None
+ ):
+ raise ValueError("`min_num_params` and `transformer_layer_cls_to_wrap` are mutually exclusive.")
+ self.fsdp_config["xla"] = self.fsdp_config.get("xla", False)
+ self.fsdp_config["xla_fsdp_v2"] = self.fsdp_config.get("xla_fsdp_v2", False)
+ self.fsdp_config["xla_fsdp_grad_ckpt"] = self.fsdp_config.get("xla_fsdp_grad_ckpt", False)
+ if self.fsdp_config["xla"]:
+ if len(self.fsdp) > 0:
+ # store XLA fsdp configuration parameters into a dictionary
+ # Copy the config to avoid modifying the original config (which may be used for JSON serialization)
+ self.xla_fsdp_config = self.fsdp_config.get("xla_fsdp_settings", {}).copy()
+ # apply appropriate string to torch.dtype conversions for parameters
+ if "compute_dtype" in self.xla_fsdp_config:
+ self.xla_fsdp_config["compute_dtype"] = getattr(torch, self.xla_fsdp_config["compute_dtype"])
+ if "buffer_dtype" in self.xla_fsdp_config:
+ self.xla_fsdp_config["buffer_dtype"] = getattr(torch, self.xla_fsdp_config["buffer_dtype"])
+ else:
+ warnings.warn("XLA FSDP can be used only when `--fsdp` is specified.")
+ else:
+ if self.fsdp_config["xla_fsdp_grad_ckpt"]:
+ warnings.warn("`--xla_fsdp_grad_ckpt` is useful only when `--xla` is set to true.")
+
+ # accelerate integration for FSDP
+ if len(self.fsdp) > 0 and not self.fsdp_config["xla"]:
+ os.environ["ACCELERATE_USE_FSDP"] = "true"
+ from accelerate.utils.constants import (
+ FSDP_AUTO_WRAP_POLICY,
+ FSDP_SHARDING_STRATEGY,
+ )
+
+ prefix = "FSDP_"
+ for fsdp_option in self.fsdp:
+ if fsdp_option.upper() in FSDP_SHARDING_STRATEGY:
+ # set environment variable for FSDP sharding strategy
+ os.environ[f"{prefix}SHARDING_STRATEGY"] = str(
+ FSDP_SHARDING_STRATEGY.index(fsdp_option.upper()) + 1
+ )
+ elif fsdp_option == FSDPOption.OFFLOAD:
+ os.environ[f"{prefix}OFFLOAD_PARAMS"] = "true"
+ elif fsdp_option == FSDPOption.AUTO_WRAP:
+ os.environ[f"{prefix}AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[0]
+ if self.fsdp_config["min_num_params"] > 0:
+ os.environ[f"{prefix}MIN_NUM_PARAMS"] = str(self.fsdp_config["min_num_params"])
+ os.environ[f"{prefix}AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[1]
+ elif self.fsdp_config.get("transformer_layer_cls_to_wrap", None) is not None:
+ os.environ[f"{prefix}TRANSFORMER_CLS_TO_WRAP"] = ",".join(
+ self.fsdp_config["transformer_layer_cls_to_wrap"]
+ )
+ prefetch_policy = self.fsdp_config.get("backward_prefetch", "NO_PREFETCH")
+ os.environ[f"{prefix}BACKWARD_PREFETCH"] = prefetch_policy.upper()
+ os.environ[f"{prefix}FORWARD_PREFETCH"] = str(self.fsdp_config.get("forward_prefetch", "false")).lower()
+
+ sync_module_states = str(self.fsdp_config.get("sync_module_states", "true")).lower()
+ cpu_ram_efficient_loading = str(self.fsdp_config.get("cpu_ram_efficient_loading", "false")).lower()
+
+ if sync_module_states == "false" and cpu_ram_efficient_loading == "true":
+ # In this case, all the processes except the main process would have random weights leading
+ # to unexpected behaviour during training, thus throwing error here to prevent it.
+ raise ValueError('`sync_module_states` must be `"True"` if `cpu_ram_efficient_loading` is `"True"`')
+
+ os.environ[f"{prefix}SYNC_MODULE_STATES"] = sync_module_states
+ os.environ[f"{prefix}CPU_RAM_EFFICIENT_LOADING"] = cpu_ram_efficient_loading
+
+ os.environ[f"{prefix}USE_ORIG_PARAMS"] = str(self.fsdp_config.get("use_orig_params", "true")).lower()
+
+ if self.tpu_metrics_debug:
+ warnings.warn(
+ "using `--tpu_metrics_debug` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"
+ " `--debug tpu_metrics_debug` instead",
+ FutureWarning,
+ )
+ if self.debug is None:
+ self.debug = " tpu_metrics_debug"
+ else:
+ self.debug += " tpu_metrics_debug"
+ self.tpu_metrics_debug = False
+
+ if isinstance(self.debug, str):
+ self.debug = [DebugOption(s) for s in self.debug.split()]
+ elif self.debug is None:
+ self.debug = []
+
+ self.deepspeed_plugin = None
+ if self.deepspeed:
+ # - must be run very last in arg parsing, since it will use a lot of these settings.
+ # - must be run before the model is created.
+ if not is_accelerate_available():
+ raise ValueError(
+ f"--deepspeed requires Accelerate to be installed: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`."
+ )
+ from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig
+
+ # will be used later by the Trainer
+ # note: leave self.deepspeed unmodified in case a user relies on it not to be modified)
+ self.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.deepspeed)
+ self.hf_deepspeed_config.trainer_config_process(self)
+
+ # Accelerate DeepSpeed Plugin
+ from accelerate.utils import DeepSpeedPlugin
+
+ os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
+ self.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.hf_deepspeed_config)
+ elif strtobool(os.environ.get("ACCELERATE_USE_DEEPSPEED", "false")):
+ # Accelerate DeepSpeed Plugin
+ from accelerate.utils import DeepSpeedPlugin
+
+ self.deepspeed_plugin = DeepSpeedPlugin()
+ mixed_precision = os.environ.get("ACCELERATE_MIXED_PRECISION", "no")
+ self.deepspeed_plugin.set_mixed_precision(mixed_precision)
+ self.deepspeed_plugin.set_deepspeed_weakref()
+
+ # Set mixed precision environment variable after DeepSpeed processing
+ # This ensures DeepSpeed config overrides have been applied to fp16/bf16 settings
+ if self.half_precision_backend != "apex":
+ mixed_precision_dtype = os.environ.get("ACCELERATE_MIXED_PRECISION", "no")
+ if self.fp16:
+ mixed_precision_dtype = "fp16"
+ elif self.bf16:
+ mixed_precision_dtype = "bf16"
+ os.environ["ACCELERATE_MIXED_PRECISION"] = mixed_precision_dtype
+
+ if self.use_cpu:
+ self.dataloader_pin_memory = False
+
+ if self.dataloader_num_workers == 0 and self.dataloader_prefetch_factor is not None:
+ raise ValueError(
+ "--dataloader_prefetch_factor can only be set when data is loaded in a different process, i.e."
+ " when --dataloader_num_workers > 1."
+ )
+
+ if self.push_to_hub_token is not None:
+ warnings.warn(
+ "`--push_to_hub_token` is deprecated and will be removed in version 5 of 🤗 Transformers. Use "
+ "`--hub_token` instead.",
+ FutureWarning,
+ )
+ self.hub_token = self.push_to_hub_token
+
+ if self.push_to_hub_model_id is not None:
+ self.hub_model_id = get_full_repo_name(
+ self.push_to_hub_model_id, organization=self.push_to_hub_organization, token=self.hub_token
+ )
+ if self.push_to_hub_organization is not None:
+ warnings.warn(
+ "`--push_to_hub_model_id` and `--push_to_hub_organization` are deprecated and will be removed in "
+ "version 5 of 🤗 Transformers. Use `--hub_model_id` instead and pass the full repo name to this "
+ f"argument (in this case {self.hub_model_id}).",
+ FutureWarning,
+ )
+ else:
+ warnings.warn(
+ "`--push_to_hub_model_id` is deprecated and will be removed in version 5 of 🤗 Transformers. Use "
+ "`--hub_model_id` instead and pass the full repo name to this argument (in this case "
+ f"{self.hub_model_id}).",
+ FutureWarning,
+ )
+ elif self.push_to_hub_organization is not None:
+ self.hub_model_id = f"{self.push_to_hub_organization}/{Path(self.output_dir).name}"
+ warnings.warn(
+ "`--push_to_hub_organization` is deprecated and will be removed in version 5 of 🤗 Transformers. Use "
+ "`--hub_model_id` instead and pass the full repo name to this argument (in this case "
+ f"{self.hub_model_id}).",
+ FutureWarning,
+ )
+
+ if self.eval_use_gather_object and not is_accelerate_available("0.30.0"):
+ raise ValueError(
+ "--eval_use_gather_object requires Accelerate to be version of `accelerate` > 0.30.0."
+ "This is not supported and we recommend you to update your version."
+ )
+
+ if self.data_seed is not None:
+ if not is_accelerate_available("1.1.0"):
+ raise NotImplementedError(
+ "data_seed requires Accelerate version `accelerate` >= 1.1.0. "
+ "This is not supported and we recommend you to update your version."
+ )
+
+ if self.include_inputs_for_metrics:
+ logger.warning(
+ "Using `include_inputs_for_metrics` is deprecated and will be removed in version 5 of 🤗 Transformers. Please use `include_for_metrics` list argument instead."
+ )
+ self.include_for_metrics.append("inputs")
+
+ if self.include_num_input_tokens_seen is True:
+ self.include_num_input_tokens_seen = "all"
+ elif self.include_num_input_tokens_seen is False:
+ self.include_num_input_tokens_seen = "no"
+
+ def __str__(self):
+ self_as_dict = asdict(self)
+
+ # Remove deprecated arguments. That code should be removed once
+ # those deprecated arguments are removed from TrainingArguments. (TODO: v5)
+ del self_as_dict["per_gpu_train_batch_size"]
+ del self_as_dict["per_gpu_eval_batch_size"]
+
+ self_as_dict = {k: f"<{k.upper()}>" if k.endswith("_token") else v for k, v in self_as_dict.items()}
+
+ attrs_as_str = [f"{k}={v},\n" for k, v in sorted(self_as_dict.items())]
+ return f"{self.__class__.__name__}(\n{''.join(attrs_as_str)})"
+
+ __repr__ = __str__
+
+ @property
+ def train_batch_size(self) -> int:
+ """
+ The actual batch size for training (may differ from `per_gpu_train_batch_size` in distributed training).
+ """
+ if self.per_gpu_train_batch_size:
+ logger.warning(
+ "Using deprecated `--per_gpu_train_batch_size` argument which will be removed in a future "
+ "version. Using `--per_device_train_batch_size` is preferred."
+ )
+ per_device_batch_size = self.per_gpu_train_batch_size or self.per_device_train_batch_size
+ train_batch_size = per_device_batch_size * max(1, self.n_gpu)
+ return train_batch_size
+
+ @property
+ def eval_batch_size(self) -> int:
+ """
+ The actual batch size for evaluation (may differ from `per_gpu_eval_batch_size` in distributed training).
+ """
+ if self.per_gpu_eval_batch_size:
+ logger.warning(
+ "Using deprecated `--per_gpu_eval_batch_size` argument which will be removed in a future "
+ "version. Using `--per_device_eval_batch_size` is preferred."
+ )
+ per_device_batch_size = self.per_gpu_eval_batch_size or self.per_device_eval_batch_size
+ eval_batch_size = per_device_batch_size * max(1, self.n_gpu)
+ return eval_batch_size
+
+ @property
+ def ddp_timeout_delta(self) -> timedelta:
+ """
+ The actual timeout for torch.distributed.init_process_group since it expects a timedelta variable.
+ """
+ return timedelta(seconds=self.ddp_timeout)
+
+ @cached_property
+ def _setup_devices(self) -> "torch.device":
+ requires_backends(self, ["torch"])
+ logger.info("PyTorch: setting up devices")
+ if not is_sagemaker_mp_enabled():
+ if not is_accelerate_available():
+ raise ImportError(
+ f"Using the `Trainer` with `PyTorch` requires `accelerate>={ACCELERATE_MIN_VERSION}`: "
+ f"Please run `pip install transformers[torch]` or `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
+ )
+ # We delay the init of `PartialState` to the end for clarity
+ accelerator_state_kwargs: dict[str, Any] = {"enabled": True, "use_configured_state": False}
+ if isinstance(self.accelerator_config, AcceleratorConfig):
+ accelerator_state_kwargs["use_configured_state"] = self.accelerator_config.pop(
+ "use_configured_state", False
+ )
+ if accelerator_state_kwargs["use_configured_state"]:
+ if PartialState._shared_state == {}:
+ raise ValueError(
+ "Passing `'use_configured_state':True` to the AcceleratorConfig requires a pre-configured "
+ "`AcceleratorState` or `PartialState` to be defined before calling `TrainingArguments`. "
+ )
+ # We rely on `PartialState` to yell if there's issues here (which it will)
+ self.distributed_state = PartialState(cpu=self.use_cpu)
+ if self.deepspeed and self.distributed_state.distributed_type != DistributedType.DEEPSPEED:
+ raise RuntimeError(
+ "Tried to use an already configured `Accelerator` or `PartialState` that was not initialized for DeepSpeed, "
+ "but also passed in a `deepspeed` configuration to the `TrainingArguments`. Please set "
+ "`use_configured_state:False` instead or setup your `Accelerator` or `PartialState` properly."
+ )
+ else:
+ AcceleratorState._reset_state(reset_partial_state=True)
+ self.distributed_state = None
+ if "ACCELERATE_USE_IPEX" not in os.environ:
+ os.environ["ACCELERATE_USE_IPEX"] = "false"
+
+ self._n_gpu = 1
+ if self.use_cpu or strtobool(os.environ.get("ACCELERATE_USE_CPU", "False")):
+ accelerator_state_kwargs["cpu"] = True
+ accelerator_state_kwargs["backend"] = self.ddp_backend
+ self._n_gpu = 0
+ elif is_sagemaker_mp_enabled():
+ accelerator_state_kwargs["enabled"] = False
+ local_rank = smp.local_rank()
+ device = torch.device("cuda", local_rank)
+ torch.cuda.set_device(device)
+ elif is_sagemaker_dp_enabled():
+ accelerator_state_kwargs["_use_sagemaker_dp"] = True
+ elif self.deepspeed:
+ accelerator_state_kwargs["use_deepspeed"] = True
+ accelerator_state_kwargs["timeout"] = timedelta(seconds=self.ddp_timeout)
+ else:
+ accelerator_state_kwargs["backend"] = self.ddp_backend
+ accelerator_state_kwargs["timeout"] = timedelta(seconds=self.ddp_timeout)
+
+ # Now we pop everything
+ if accelerator_state_kwargs.pop("enabled", False) and not accelerator_state_kwargs.pop(
+ "use_configured_state", False
+ ):
+ # We need to patch this env var when enabling to detect deepspeed
+ use_deepspeed = accelerator_state_kwargs.pop("use_deepspeed", False)
+ if use_deepspeed:
+ os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
+ self.distributed_state = PartialState(**accelerator_state_kwargs)
+ if use_deepspeed:
+ del os.environ["ACCELERATE_USE_DEEPSPEED"]
+ if not is_sagemaker_mp_enabled():
+ device = self.distributed_state.device
+ self.local_rank = self.distributed_state.local_process_index
+ if dist.is_available() and dist.is_initialized() and self.parallel_mode != ParallelMode.DISTRIBUTED:
+ logger.warning(
+ "torch.distributed process group is initialized, but parallel_mode != ParallelMode.DISTRIBUTED. "
+ "In order to use Torch DDP, launch your script with `python -m torch.distributed.launch"
+ )
+ if is_torch_xla_available():
+ device = self.distributed_state.device
+ self._n_gpu = 0
+ elif is_sagemaker_dp_enabled() or is_sagemaker_mp_enabled():
+ # Already set _n_gpu
+ pass
+ elif self.distributed_state.distributed_type == DistributedType.NO:
+ if self.use_mps_device:
+ warnings.warn(
+ "`use_mps_device` is deprecated and will be removed in version 5.0 of 🤗 Transformers. "
+ "`mps` device will be used by default if available similar to the way `cuda` device is used."
+ "Therefore, no action from user is required. "
+ )
+ if device.type != "mps":
+ raise ValueError(
+ "Either you do not have an MPS-enabled device on this machine or MacOS version is not 12.3+ "
+ "or current PyTorch install was not built with MPS enabled."
+ )
+ if self.use_cpu:
+ device = torch.device("cpu")
+ elif is_torch_mps_available():
+ device = torch.device("mps")
+ elif is_torch_xpu_available():
+ if not is_ipex_available() and not is_accelerate_available("0.32.0.dev"):
+ raise ImportError("Using the XPU PyTorch backend requires `accelerate>=0.32.0.dev`")
+ device = torch.device("xpu:0")
+ torch.xpu.set_device(device)
+ elif is_torch_mlu_available():
+ device = torch.device("mlu:0")
+ torch.mlu.set_device(device)
+ elif is_torch_musa_available():
+ device = torch.device("musa:0")
+ torch.musa.set_device(device)
+ elif is_torch_npu_available():
+ device = torch.device("npu:0")
+ torch.npu.set_device(device)
+ elif is_torch_hpu_available():
+ device = torch.device("hpu:0")
+ torch.hpu.set_device(device)
+ else:
+ # if n_gpu is > 1 we'll use nn.DataParallel.
+ # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
+ # Explicitly set CUDA to the first (index 0) CUDA device, otherwise `set_device` will
+ # trigger an error that a device index is missing. Index 0 takes into account the
+ # GPUs available in the environment, so `CUDA_VISIBLE_DEVICES=1,2` with `cuda:0`
+ # will use the first GPU in that env, i.e. GPU#1
+ device = torch.device(
+ "cuda:0" if torch.cuda.is_available() else os.environ.get("ACCELERATE_TORCH_DEVICE", "cpu")
+ )
+ # Sometimes the line in the postinit has not been run before we end up here, so just checking we're not at
+ # the default value.
+ self._n_gpu = torch.cuda.device_count()
+ if device.type == "cuda":
+ torch.cuda.set_device(device)
+ return device
+
+ @property
+ def device(self) -> "torch.device":
+ """
+ The device used by this process.
+ """
+ requires_backends(self, ["torch"])
+ return self._setup_devices
+
+ @property
+ def n_gpu(self):
+ """
+ The number of GPUs used by this process.
+
+ Note:
+ This will only be greater than one when you have multiple GPUs available but are not using distributed
+ training. For distributed training, it will always be 1.
+ """
+ requires_backends(self, ["torch"])
+ # Make sure `self._n_gpu` is properly setup.
+ if not hasattr(self, "_n_gpu"):
+ _ = self._setup_devices
+ return self._n_gpu
+
+ @property
+ def parallel_mode(self):
+ """
+ The current mode used for parallelism if multiple GPUs/TPU cores are available. One of:
+
+ - `ParallelMode.NOT_PARALLEL`: no parallelism (CPU or one GPU).
+ - `ParallelMode.NOT_DISTRIBUTED`: several GPUs in one single process (uses `torch.nn.DataParallel`).
+ - `ParallelMode.DISTRIBUTED`: several GPUs, each having its own process (uses
+ `torch.nn.DistributedDataParallel`).
+ - `ParallelMode.TPU`: several TPU cores.
+ """
+ requires_backends(self, ["torch"])
+ if is_torch_xla_available():
+ return ParallelMode.TPU
+ elif is_sagemaker_mp_enabled():
+ return ParallelMode.SAGEMAKER_MODEL_PARALLEL
+ elif is_sagemaker_dp_enabled():
+ return ParallelMode.SAGEMAKER_DATA_PARALLEL
+ elif (
+ self.distributed_state is not None and self.distributed_state.distributed_type != DistributedType.NO
+ ) or (self.distributed_state is None and self.local_rank != -1):
+ return ParallelMode.DISTRIBUTED
+ elif self.n_gpu > 1:
+ return ParallelMode.NOT_DISTRIBUTED
+ else:
+ return ParallelMode.NOT_PARALLEL
+
+ @property
+ def world_size(self):
+ """
+ The number of processes used in parallel.
+ """
+ requires_backends(self, ["torch"])
+ if self.distributed_state is not None:
+ return self.distributed_state.num_processes
+ elif is_sagemaker_mp_enabled():
+ return smp.dp_size() if not smp.state.cfg.prescaled_batch else smp.rdp_size()
+ return 1
+
+ @property
+ def process_index(self):
+ """
+ The index of the current process used.
+ """
+ requires_backends(self, ["torch"])
+ if self.distributed_state is not None:
+ return self.distributed_state.process_index
+ elif is_sagemaker_mp_enabled():
+ return smp.dp_rank() if not smp.state.cfg.prescaled_batch else smp.rdp_rank()
+ return 0
+
+ @property
+ def local_process_index(self):
+ """
+ The index of the local process used.
+ """
+ requires_backends(self, ["torch"])
+
+ if self.distributed_state is not None:
+ return self.distributed_state.local_process_index
+ elif is_sagemaker_mp_enabled():
+ return smp.local_rank()
+ return 0
+
+ @property
+ def should_log(self):
+ """
+ Whether or not the current process should produce log.
+ """
+ if self.log_on_each_node:
+ return self.local_process_index == 0
+ else:
+ if is_sagemaker_mp_enabled():
+ return smp.rank() == 0
+ else:
+ return self.process_index == 0
+
+ @property
+ def should_save(self):
+ """
+ Whether or not the current process should write to disk, e.g., to save models and checkpoints.
+ """
+ if self.save_on_each_node:
+ return self.local_process_index == 0
+ else:
+ if is_sagemaker_mp_enabled():
+ return smp.rank() == 0
+ else:
+ return self.process_index == 0
+
+ def get_process_log_level(self):
+ """
+ Returns the log level to be used depending on whether this process is the main process of node 0, main process
+ of node non-0, or a non-main process.
+
+ For the main process the log level defaults to the logging level set (`logging.WARNING` if you didn't do
+ anything) unless overridden by `log_level` argument.
+
+ For the replica processes the log level defaults to `logging.WARNING` unless overridden by `log_level_replica`
+ argument.
+
+ The choice between the main and replica process settings is made according to the return value of `should_log`.
+ """
+
+ # convert to int
+ log_level = trainer_log_levels[self.log_level]
+ log_level_replica = trainer_log_levels[self.log_level_replica]
+
+ log_level_main_node = logging.get_verbosity() if log_level == -1 else log_level
+ log_level_replica_node = logging.get_verbosity() if log_level_replica == -1 else log_level_replica
+ return log_level_main_node if self.should_log else log_level_replica_node
+
+ @property
+ def place_model_on_device(self):
+ """
+ Can be subclassed and overridden for some specific integrations.
+ """
+ return not is_sagemaker_mp_enabled()
+
+ @property
+ def _no_sync_in_gradient_accumulation(self):
+ """
+ Whether or not to use no_sync for the gradients when doing gradient accumulation.
+ """
+ return not (
+ self.deepspeed or is_sagemaker_dp_enabled() or is_sagemaker_mp_enabled() or is_torch_neuroncore_available()
+ )
+
+ @contextlib.contextmanager
+ def main_process_first(self, local=True, desc="work"):
+ """
+ A context manager for torch distributed environment where on needs to do something on the main process, while
+ blocking replicas, and when it's finished releasing the replicas.
+
+ One such use is for `datasets`'s `map` feature which to be efficient should be run once on the main process,
+ which upon completion saves a cached version of results and which then automatically gets loaded by the
+ replicas.
+
+ Args:
+ local (`bool`, *optional*, defaults to `True`):
+ if `True` first means process of rank 0 of each node if `False` first means process of rank 0 of node
+ rank 0 In multi-node environment with a shared filesystem you most likely will want to use
+ `local=False` so that only the main process of the first node will do the processing. If however, the
+ filesystem is not shared, then the main process of each node will need to do the processing, which is
+ the default behavior.
+ desc (`str`, *optional*, defaults to `"work"`):
+ a work description to be used in debug logs
+
+ """
+ if is_torch_available() and self.world_size > 1:
+ main_process_desc = "main local process" if local else "main process"
+ if self.distributed_state is not None:
+ is_main_process = (
+ self.distributed_state.is_local_main_process if local else self.distributed_state.is_main_process
+ )
+ elif is_sagemaker_mp_enabled():
+ is_main_process = smp.rank() == 0
+
+ try:
+ if not is_main_process:
+ # tell all replicas to wait
+ logger.debug(f"{self.process_index}: waiting for the {main_process_desc} to perform {desc}")
+
+ if is_torch_xla_available():
+ xm.rendezvous(desc)
+ else:
+ dist.barrier()
+ yield
+ finally:
+ if is_main_process:
+ # the wait is over
+ logger.debug(f"{self.process_index}: {main_process_desc} completed {desc}, releasing all replicas")
+ if is_torch_xla_available():
+ xm.rendezvous(desc)
+ else:
+ dist.barrier()
+ else:
+ yield
+
+ def get_warmup_steps(self, num_training_steps: int):
+ """
+ Get number of steps used for a linear warmup.
+ """
+ warmup_steps = (
+ self.warmup_steps if self.warmup_steps > 0 else math.ceil(num_training_steps * self.warmup_ratio)
+ )
+ return warmup_steps
+
+ def _dict_dtype_to_str(self, d: dict[str, Any]) -> None:
+ """
+ Checks whether the passed dictionary and its nested dicts have a *dtype* key and if it's not None,
+ converts torch.dtype to a string of just the type. For example, `torch.float32` get converted into *"float32"*
+ string, which can then be stored in the json format.
+ """
+ if d.get("dtype") is not None and not isinstance(d["dtype"], str):
+ d["dtype"] = str(d["dtype"]).split(".")[1]
+ for value in d.values():
+ if isinstance(value, dict):
+ self._dict_dtype_to_str(value)
+
+ def to_dict(self):
+ """
+ Serializes this instance while replace `Enum` by their values (for JSON serialization support). It obfuscates
+ the token values by removing their value.
+ """
+ # filter out fields that are defined as field(init=False)
+ d = {field.name: getattr(self, field.name) for field in fields(self) if field.init}
+
+ for k, v in d.items():
+ if isinstance(v, Enum):
+ d[k] = v.value
+ if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
+ d[k] = [x.value for x in v]
+ if k.endswith("_token"):
+ d[k] = f"<{k.upper()}>"
+ # Handle the accelerator_config if passed
+ if is_accelerate_available() and isinstance(v, AcceleratorConfig):
+ d[k] = v.to_dict()
+ # Handle the quantization_config if passed
+ if k == "model_init_kwargs" and isinstance(v, dict) and "quantization_config" in v:
+ quantization_config = v.get("quantization_config")
+ if quantization_config and not isinstance(quantization_config, dict):
+ d[k]["quantization_config"] = quantization_config.to_dict()
+ if k == "parallelism_config" and v is not None:
+ d[k] = v.to_json()
+
+ self._dict_dtype_to_str(d)
+
+ return d
+
+ def to_json_string(self):
+ """
+ Serializes this instance to a JSON string.
+ """
+ return json.dumps(self.to_dict(), indent=2)
+
+ def to_sanitized_dict(self) -> dict[str, Any]:
+ """
+ Sanitized serialization to use with TensorBoard's hparams
+ """
+ d = self.to_dict()
+ d = {**d, **{"train_batch_size": self.train_batch_size, "eval_batch_size": self.eval_batch_size}}
+
+ valid_types = [bool, int, float, str]
+ if is_torch_available():
+ valid_types.append(torch.Tensor)
+
+ return {k: v if type(v) in valid_types else str(v) for k, v in d.items()}
+
+ # The following methods are there to simplify the instantiation of `TrainingArguments`
+ def set_training(
+ self,
+ learning_rate: float = 5e-5,
+ batch_size: int = 8,
+ weight_decay: float = 0,
+ num_epochs: float = 3,
+ max_steps: int = -1,
+ gradient_accumulation_steps: int = 1,
+ seed: int = 42,
+ gradient_checkpointing: bool = False,
+ ):
+ """
+ A method that regroups all basic arguments linked to the training.
+
+
+
+ Calling this method will automatically set `self.do_train` to `True`.
+
+
+
+ Args:
+ learning_rate (`float`, *optional*, defaults to 5e-5):
+ The initial learning rate for the optimizer.
+ batch_size (`int` *optional*, defaults to 8):
+ The batch size per device (GPU/TPU core/CPU...) used for training.
+ weight_decay (`float`, *optional*, defaults to 0):
+ The weight decay to apply (if not zero) to all layers except all bias and LayerNorm weights in the
+ optimizer.
+ num_train_epochs(`float`, *optional*, defaults to 3.0):
+ Total number of training epochs to perform (if not an integer, will perform the decimal part percents
+ of the last epoch before stopping training).
+ max_steps (`int`, *optional*, defaults to -1):
+ If set to a positive number, the total number of training steps to perform. Overrides `num_train_epochs`.
+ For a finite dataset, training is reiterated through the dataset (if all data is exhausted) until
+ `max_steps` is reached.
+ gradient_accumulation_steps (`int`, *optional*, defaults to 1):
+ Number of updates steps to accumulate the gradients for, before performing a backward/update pass.
+
+
+
+ When using gradient accumulation, one step is counted as one step with backward pass. Therefore,
+ logging, evaluation, save will be conducted every `gradient_accumulation_steps * xxx_step` training
+ examples.
+
+
+
+ seed (`int`, *optional*, defaults to 42):
+ Random seed that will be set at the beginning of training. To ensure reproducibility across runs, use
+ the [`~Trainer.model_init`] function to instantiate the model if it has some randomly initialized
+ parameters.
+ gradient_checkpointing (`bool`, *optional*, defaults to `False`):
+ If True, use gradient checkpointing to save memory at the expense of slower backward pass.
+
+ Example:
+
+ ```py
+ >>> from transformers import TrainingArguments
+
+ >>> args = TrainingArguments("working_dir")
+ >>> args = args.set_training(learning_rate=1e-4, batch_size=32)
+ >>> args.learning_rate
+ 1e-4
+ ```
+ """
+ self.do_train = True
+ self.learning_rate = learning_rate
+ self.per_device_train_batch_size = batch_size
+ self.weight_decay = weight_decay
+ self.num_train_epochs = num_epochs
+ self.max_steps = max_steps
+ self.gradient_accumulation_steps = gradient_accumulation_steps
+ self.seed = seed
+ self.gradient_checkpointing = gradient_checkpointing
+ return self
+
+ def set_evaluate(
+ self,
+ strategy: Union[str, IntervalStrategy] = "no",
+ steps: int = 500,
+ batch_size: int = 8,
+ accumulation_steps: Optional[int] = None,
+ delay: Optional[float] = None,
+ loss_only: bool = False,
+ jit_mode: bool = False,
+ ):
+ """
+ A method that regroups all arguments linked to evaluation.
+
+ Args:
+ strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"no"`):
+ The evaluation strategy to adopt during training. Possible values are:
+
+ - `"no"`: No evaluation is done during training.
+ - `"steps"`: Evaluation is done (and logged) every `steps`.
+ - `"epoch"`: Evaluation is done at the end of each epoch.
+
+ Setting a `strategy` different from `"no"` will set `self.do_eval` to `True`.
+ steps (`int`, *optional*, defaults to 500):
+ Number of update steps between two evaluations if `strategy="steps"`.
+ batch_size (`int` *optional*, defaults to 8):
+ The batch size per device (GPU/TPU core/CPU...) used for evaluation.
+ accumulation_steps (`int`, *optional*):
+ Number of predictions steps to accumulate the output tensors for, before moving the results to the CPU.
+ If left unset, the whole predictions are accumulated on GPU/TPU before being moved to the CPU (faster
+ but requires more memory).
+ delay (`float`, *optional*):
+ Number of epochs or steps to wait for before the first evaluation can be performed, depending on the
+ eval_strategy.
+ loss_only (`bool`, *optional*, defaults to `False`):
+ Ignores all outputs except the loss.
+ jit_mode (`bool`, *optional*):
+ Whether or not to use PyTorch jit trace for inference.
+
+ Example:
+
+ ```py
+ >>> from transformers import TrainingArguments
+
+ >>> args = TrainingArguments("working_dir")
+ >>> args = args.set_evaluate(strategy="steps", steps=100)
+ >>> args.eval_steps
+ 100
+ ```
+ """
+ self.eval_strategy = IntervalStrategy(strategy)
+ if self.eval_strategy == IntervalStrategy.STEPS and steps == 0:
+ raise ValueError("Setting `strategy` as 'steps' requires a positive value for `steps`.")
+ self.do_eval = self.eval_strategy != IntervalStrategy.NO
+ self.eval_steps = steps
+ self.per_device_eval_batch_size = batch_size
+ self.eval_accumulation_steps = accumulation_steps
+ self.eval_delay = delay
+ self.prediction_loss_only = loss_only
+ self.jit_mode_eval = jit_mode
+ return self
+
+ def set_testing(
+ self,
+ batch_size: int = 8,
+ loss_only: bool = False,
+ jit_mode: bool = False,
+ ):
+ """
+ A method that regroups all basic arguments linked to testing on a held-out dataset.
+
+
+
+ Calling this method will automatically set `self.do_predict` to `True`.
+
+
+
+ Args:
+ batch_size (`int` *optional*, defaults to 8):
+ The batch size per device (GPU/TPU core/CPU...) used for testing.
+ loss_only (`bool`, *optional*, defaults to `False`):
+ Ignores all outputs except the loss.
+ jit_mode (`bool`, *optional*):
+ Whether or not to use PyTorch jit trace for inference.
+
+ Example:
+
+ ```py
+ >>> from transformers import TrainingArguments
+
+ >>> args = TrainingArguments("working_dir")
+ >>> args = args.set_testing(batch_size=32)
+ >>> args.per_device_eval_batch_size
+ 32
+ ```
+ """
+ self.do_predict = True
+ self.per_device_eval_batch_size = batch_size
+ self.prediction_loss_only = loss_only
+ self.jit_mode_eval = jit_mode
+ return self
+
+ def set_save(
+ self,
+ strategy: Union[str, IntervalStrategy] = "steps",
+ steps: int = 500,
+ total_limit: Optional[int] = None,
+ on_each_node: bool = False,
+ ):
+ """
+ A method that regroups all arguments linked to checkpoint saving.
+
+ Args:
+ strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"steps"`):
+ The checkpoint save strategy to adopt during training. Possible values are:
+
+ - `"no"`: No save is done during training.
+ - `"epoch"`: Save is done at the end of each epoch.
+ - `"steps"`: Save is done every `save_steps`.
+
+ steps (`int`, *optional*, defaults to 500):
+ Number of updates steps before two checkpoint saves if `strategy="steps"`.
+ total_limit (`int`, *optional*):
+ If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in
+ `output_dir`.
+ on_each_node (`bool`, *optional*, defaults to `False`):
+ When doing multi-node distributed training, whether to save models and checkpoints on each node, or
+ only on the main one.
+
+ This should not be activated when the different nodes use the same storage as the files will be saved
+ with the same names for each node.
+
+ Example:
+
+ ```py
+ >>> from transformers import TrainingArguments
+
+ >>> args = TrainingArguments("working_dir")
+ >>> args = args.set_save(strategy="steps", steps=100)
+ >>> args.save_steps
+ 100
+ ```
+ """
+ self.save_strategy = SaveStrategy(strategy)
+ if self.save_strategy == SaveStrategy.STEPS and steps == 0:
+ raise ValueError("Setting `strategy` as 'steps' requires a positive value for `steps`.")
+ self.save_steps = steps
+ self.save_total_limit = total_limit
+ self.save_on_each_node = on_each_node
+ return self
+
+ def set_logging(
+ self,
+ strategy: Union[str, IntervalStrategy] = "steps",
+ steps: int = 500,
+ report_to: Union[str, list[str]] = "none",
+ level: str = "passive",
+ first_step: bool = False,
+ nan_inf_filter: bool = False,
+ on_each_node: bool = False,
+ replica_level: str = "passive",
+ ):
+ """
+ A method that regroups all arguments linked to logging.
+
+ Args:
+ strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"steps"`):
+ The logging strategy to adopt during training. Possible values are:
+
+ - `"no"`: No logging is done during training.
+ - `"epoch"`: Logging is done at the end of each epoch.
+ - `"steps"`: Logging is done every `logging_steps`.
+
+ steps (`int`, *optional*, defaults to 500):
+ Number of update steps between two logs if `strategy="steps"`.
+ level (`str`, *optional*, defaults to `"passive"`):
+ Logger log level to use on the main process. Possible choices are the log levels as strings: `"debug"`,
+ `"info"`, `"warning"`, `"error"` and `"critical"`, plus a `"passive"` level which doesn't set anything
+ and lets the application set the level.
+ report_to (`str` or `list[str]`, *optional*, defaults to `"all"`):
+ The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,
+ `"clearml"`, `"codecarbon"`, `"comet_ml"`, `"dagshub"`, `"dvclive"`, `"flyte"`, `"mlflow"`,
+ `"neptune"`, `"swanlab"`, `"tensorboard"`, `"trackio"` and `"wandb"`. Use `"all"` to report to all
+ integrations installed, `"none"` for no integrations.
+ first_step (`bool`, *optional*, defaults to `False`):
+ Whether to log and evaluate the first `global_step` or not.
+ nan_inf_filter (`bool`, *optional*, defaults to `True`):
+ Whether to filter `nan` and `inf` losses for logging. If set to `True` the loss of every step that is
+ `nan` or `inf` is filtered and the average loss of the current logging window is taken instead.
+
+
+
+ `nan_inf_filter` only influences the logging of loss values, it does not change the behavior the
+ gradient is computed or applied to the model.
+
+
+
+ on_each_node (`bool`, *optional*, defaults to `True`):
+ In multinode distributed training, whether to log using `log_level` once per node, or only on the main
+ node.
+ replica_level (`str`, *optional*, defaults to `"passive"`):
+ Logger log level to use on replicas. Same choices as `log_level`
+
+ Example:
+
+ ```py
+ >>> from transformers import TrainingArguments
+
+ >>> args = TrainingArguments("working_dir")
+ >>> args = args.set_logging(strategy="steps", steps=100)
+ >>> args.logging_steps
+ 100
+ ```
+ """
+ self.logging_strategy = IntervalStrategy(strategy)
+ if self.logging_strategy == IntervalStrategy.STEPS and steps == 0:
+ raise ValueError("Setting `strategy` as 'steps' requires a positive value for `steps`.")
+ self.logging_steps = steps
+ self.report_to = report_to
+ self.log_level = level
+ self.logging_first_step = first_step
+ self.logging_nan_inf_filter = nan_inf_filter
+ self.log_on_each_node = on_each_node
+ self.log_level_replica = replica_level
+ return self
+
+ def set_push_to_hub(
+ self,
+ model_id: str,
+ strategy: Union[str, HubStrategy] = "every_save",
+ token: Optional[str] = None,
+ private_repo: Optional[bool] = None,
+ always_push: bool = False,
+ revision: Optional[str] = None,
+ ):
+ """
+ A method that regroups all arguments linked to synchronizing checkpoints with the Hub.
+
+
+
+ Calling this method will set `self.push_to_hub` to `True`, which means the `output_dir` will begin a git
+ directory synced with the repo (determined by `model_id`) and the content will be pushed each time a save is
+ triggered (depending on your `self.save_strategy`). Calling [`~Trainer.save_model`] will also trigger a push.
+
+
+
+ Args:
+ model_id (`str`):
+ The name of the repository to keep in sync with the local *output_dir*. It can be a simple model ID in
+ which case the model will be pushed in your namespace. Otherwise it should be the whole repository
+ name, for instance `"user_name/model"`, which allows you to push to an organization you are a member of
+ with `"organization_name/model"`.
+ strategy (`str` or [`~trainer_utils.HubStrategy`], *optional*, defaults to `"every_save"`):
+ Defines the scope of what is pushed to the Hub and when. Possible values are:
+
+ - `"end"`: push the model, its configuration, the processing_class e.g. tokenizer (if passed along to the [`Trainer`]) and a
+ draft of a model card when the [`~Trainer.save_model`] method is called.
+ - `"every_save"`: push the model, its configuration, the processing_class e.g. tokenizer (if passed along to the [`Trainer`])
+ and
+ a draft of a model card each time there is a model save. The pushes are asynchronous to not block
+ training, and in case the save are very frequent, a new push is only attempted if the previous one is
+ finished. A last push is made with the final model at the end of training.
+ - `"checkpoint"`: like `"every_save"` but the latest checkpoint is also pushed in a subfolder named
+ last-checkpoint, allowing you to resume training easily with
+ `trainer.train(resume_from_checkpoint="last-checkpoint")`.
+ - `"all_checkpoints"`: like `"checkpoint"` but all checkpoints are pushed like they appear in the
+ output
+ folder (so you will get one checkpoint folder per folder in your final repository)
+
+ token (`str`, *optional*):
+ The token to use to push the model to the Hub. Will default to the token in the cache folder obtained
+ with `hf auth login`.
+ private_repo (`bool`, *optional*, defaults to `False`):
+ Whether to make the repo private. If `None` (default), the repo will be public unless the organization's default is private. This value is ignored if the repo already exists.
+ always_push (`bool`, *optional*, defaults to `False`):
+ Unless this is `True`, the `Trainer` will skip pushing a checkpoint when the previous push is not
+ finished.
+ revision (`str`, *optional*):
+ The revision to use when pushing to the Hub. Can be a branch name, a tag, or a commit hash.
+
+ Example:
+
+ ```py
+ >>> from transformers import TrainingArguments
+
+ >>> args = TrainingArguments("working_dir")
+ >>> args = args.set_push_to_hub("me/awesome-model")
+ >>> args.hub_model_id
+ 'me/awesome-model'
+ ```
+ """
+ self.push_to_hub = True
+ self.hub_model_id = model_id
+ self.hub_strategy = HubStrategy(strategy)
+ self.hub_token = token
+ self.hub_private_repo = private_repo
+ self.hub_always_push = always_push
+ self.hub_revision = revision
+ return self
+
+ def set_optimizer(
+ self,
+ name: Union[str, OptimizerNames] = "adamw_torch",
+ learning_rate: float = 5e-5,
+ weight_decay: float = 0,
+ beta1: float = 0.9,
+ beta2: float = 0.999,
+ epsilon: float = 1e-8,
+ args: Optional[str] = None,
+ ):
+ """
+ A method that regroups all arguments linked to the optimizer and its hyperparameters.
+
+ Args:
+ name (`str` or [`training_args.OptimizerNames`], *optional*, defaults to `"adamw_torch"`):
+ The optimizer to use: `"adamw_torch"`, `"adamw_torch_fused"`, `"adamw_apex_fused"`,
+ `"adamw_anyprecision"` or `"adafactor"`.
+ learning_rate (`float`, *optional*, defaults to 5e-5):
+ The initial learning rate.
+ weight_decay (`float`, *optional*, defaults to 0):
+ The weight decay to apply (if not zero) to all layers except all bias and LayerNorm weights.
+ beta1 (`float`, *optional*, defaults to 0.9):
+ The beta1 hyperparameter for the adam optimizer or its variants.
+ beta2 (`float`, *optional*, defaults to 0.999):
+ The beta2 hyperparameter for the adam optimizer or its variants.
+ epsilon (`float`, *optional*, defaults to 1e-8):
+ The epsilon hyperparameter for the adam optimizer or its variants.
+ args (`str`, *optional*):
+ Optional arguments that are supplied to AnyPrecisionAdamW (only useful when
+ `optim="adamw_anyprecision"`).
+
+ Example:
+
+ ```py
+ >>> from transformers import TrainingArguments
+
+ >>> args = TrainingArguments("working_dir")
+ >>> args = args.set_optimizer(name="adamw_torch", beta1=0.8)
+ >>> args.optim
+ 'adamw_torch'
+ ```
+ """
+ self.optim = OptimizerNames(name)
+ self.learning_rate = learning_rate
+ self.weight_decay = weight_decay
+ self.adam_beta1 = beta1
+ self.adam_beta2 = beta2
+ self.adam_epsilon = epsilon
+ self.optim_args = args
+ return self
+
+ def set_lr_scheduler(
+ self,
+ name: Union[str, SchedulerType] = "linear",
+ num_epochs: float = 3.0,
+ max_steps: int = -1,
+ warmup_ratio: float = 0,
+ warmup_steps: int = 0,
+ ):
+ """
+ A method that regroups all arguments linked to the learning rate scheduler and its hyperparameters.
+
+ Args:
+ name (`str` or [`SchedulerType`], *optional*, defaults to `"linear"`):
+ The scheduler type to use. See the documentation of [`SchedulerType`] for all possible values.
+ num_epochs(`float`, *optional*, defaults to 3.0):
+ Total number of training epochs to perform (if not an integer, will perform the decimal part percents
+ of the last epoch before stopping training).
+ max_steps (`int`, *optional*, defaults to -1):
+ If set to a positive number, the total number of training steps to perform. Overrides `num_train_epochs`.
+ For a finite dataset, training is reiterated through the dataset (if all data is exhausted) until
+ `max_steps` is reached.
+ warmup_ratio (`float`, *optional*, defaults to 0.0):
+ Ratio of total training steps used for a linear warmup from 0 to `learning_rate`.
+ warmup_steps (`int`, *optional*, defaults to 0):
+ Number of steps used for a linear warmup from 0 to `learning_rate`. Overrides any effect of
+ `warmup_ratio`.
+
+ Example:
+
+ ```py
+ >>> from transformers import TrainingArguments
+
+ >>> args = TrainingArguments("working_dir")
+ >>> args = args.set_lr_scheduler(name="cosine", warmup_ratio=0.05)
+ >>> args.warmup_ratio
+ 0.05
+ ```
+ """
+ self.lr_scheduler_type = SchedulerType(name)
+ self.num_train_epochs = num_epochs
+ self.max_steps = max_steps
+ self.warmup_ratio = warmup_ratio
+ self.warmup_steps = warmup_steps
+ return self
+
+ def set_dataloader(
+ self,
+ train_batch_size: int = 8,
+ eval_batch_size: int = 8,
+ drop_last: bool = False,
+ num_workers: int = 0,
+ pin_memory: bool = True,
+ persistent_workers: bool = False,
+ prefetch_factor: Optional[int] = None,
+ auto_find_batch_size: bool = False,
+ ignore_data_skip: bool = False,
+ sampler_seed: Optional[int] = None,
+ ):
+ """
+ A method that regroups all arguments linked to the dataloaders creation.
+
+ Args:
+ drop_last (`bool`, *optional*, defaults to `False`):
+ Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch
+ size) or not.
+ num_workers (`int`, *optional*, defaults to 0):
+ Number of subprocesses to use for data loading (PyTorch only). 0 means that the data will be loaded in
+ the main process.
+ pin_memory (`bool`, *optional*, defaults to `True`):
+ Whether you want to pin memory in data loaders or not. Will default to `True`.
+ persistent_workers (`bool`, *optional*, defaults to `False`):
+ If True, the data loader will not shut down the worker processes after a dataset has been consumed
+ once. This allows to maintain the workers Dataset instances alive. Can potentially speed up training,
+ but will increase RAM usage. Will default to `False`.
+ prefetch_factor (`int`, *optional*):
+ Number of batches loaded in advance by each worker.
+ 2 means there will be a total of 2 * num_workers batches prefetched across all workers.
+ auto_find_batch_size (`bool`, *optional*, defaults to `False`)
+ Whether to find a batch size that will fit into memory automatically through exponential decay,
+ avoiding CUDA Out-of-Memory errors. Requires accelerate to be installed (`pip install accelerate`)
+ ignore_data_skip (`bool`, *optional*, defaults to `False`):
+ When resuming training, whether or not to skip the epochs and batches to get the data loading at the
+ same stage as in the previous training. If set to `True`, the training will begin faster (as that
+ skipping step can take a long time) but will not yield the same results as the interrupted training
+ would have.
+ sampler_seed (`int`, *optional*):
+ Random seed to be used with data samplers. If not set, random generators for data sampling will use the
+ same seed as `self.seed`. This can be used to ensure reproducibility of data sampling, independent of
+ the model seed.
+
+ Example:
+
+ ```py
+ >>> from transformers import TrainingArguments
+
+ >>> args = TrainingArguments("working_dir")
+ >>> args = args.set_dataloader(train_batch_size=16, eval_batch_size=64)
+ >>> args.per_device_train_batch_size
+ 16
+ ```
+ """
+ self.per_device_train_batch_size = train_batch_size
+ self.per_device_eval_batch_size = eval_batch_size
+ self.dataloader_drop_last = drop_last
+ self.dataloader_num_workers = num_workers
+ self.dataloader_pin_memory = pin_memory
+ self.dataloader_persistent_workers = persistent_workers
+ self.dataloader_prefetch_factor = prefetch_factor
+ self.auto_find_batch_size = auto_find_batch_size
+ self.ignore_data_skip = ignore_data_skip
+ self.data_seed = sampler_seed
+ return self
+
+
+class ParallelMode(Enum):
+ NOT_PARALLEL = "not_parallel"
+ NOT_DISTRIBUTED = "not_distributed"
+ DISTRIBUTED = "distributed"
+ SAGEMAKER_MODEL_PARALLEL = "sagemaker_model_parallel"
+ SAGEMAKER_DATA_PARALLEL = "sagemaker_data_parallel"
+ TPU = "tpu"
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/training_args_seq2seq.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/training_args_seq2seq.py
new file mode 100644
index 0000000000000000000000000000000000000000..5342b7add3932c542e35247e52920d8fc91ed325
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/training_args_seq2seq.py
@@ -0,0 +1,90 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Optional, Union
+
+from .generation.configuration_utils import GenerationConfig
+from .training_args import TrainingArguments
+from .utils import add_start_docstrings
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+@add_start_docstrings(TrainingArguments.__doc__)
+class Seq2SeqTrainingArguments(TrainingArguments):
+ """
+ Args:
+ predict_with_generate (`bool`, *optional*, defaults to `False`):
+ Whether to use generate to calculate generative metrics (ROUGE, BLEU).
+ generation_max_length (`int`, *optional*):
+ The `max_length` to use on each evaluation loop when `predict_with_generate=True`. Will default to the
+ `max_length` value of the model configuration.
+ generation_num_beams (`int`, *optional*):
+ The `num_beams` to use on each evaluation loop when `predict_with_generate=True`. Will default to the
+ `num_beams` value of the model configuration.
+ generation_config (`str` or `Path` or [`~generation.GenerationConfig`], *optional*):
+ Allows to load a [`~generation.GenerationConfig`] from the `from_pretrained` method. This can be either:
+
+ - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
+ huggingface.co.
+ - a path to a *directory* containing a configuration file saved using the
+ [`~GenerationConfig.save_pretrained`] method, e.g., `./my_model_directory/`.
+ - a [`~generation.GenerationConfig`] object.
+ """
+
+ sortish_sampler: bool = field(default=False, metadata={"help": "Whether to use SortishSampler or not."})
+ predict_with_generate: bool = field(
+ default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
+ )
+ generation_max_length: Optional[int] = field(
+ default=None,
+ metadata={
+ "help": (
+ "The `max_length` to use on each evaluation loop when `predict_with_generate=True`. Will default "
+ "to the `max_length` value of the model configuration."
+ )
+ },
+ )
+ generation_num_beams: Optional[int] = field(
+ default=None,
+ metadata={
+ "help": (
+ "The `num_beams` to use on each evaluation loop when `predict_with_generate=True`. Will default "
+ "to the `num_beams` value of the model configuration."
+ )
+ },
+ )
+ generation_config: Optional[Union[str, Path, GenerationConfig]] = field(
+ default=None,
+ metadata={
+ "help": "Model id, file path or url pointing to a GenerationConfig json file, to use during prediction."
+ },
+ )
+
+ def to_dict(self):
+ """
+ Serializes this instance while replace `Enum` by their values and `GenerationConfig` by dictionaries (for JSON
+ serialization support). It obfuscates the token values by removing their value.
+ """
+ # filter out fields that are defined as field(init=False)
+ d = super().to_dict()
+ for k, v in d.items():
+ if isinstance(v, GenerationConfig):
+ d[k] = v.to_dict()
+ return d
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/training_args_tf.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/training_args_tf.py
new file mode 100644
index 0000000000000000000000000000000000000000..24763dabf9167896f69cd939ccec99f37d5fc0de
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/training_args_tf.py
@@ -0,0 +1,300 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import warnings
+from dataclasses import dataclass, field
+from functools import cached_property
+from typing import Optional
+
+from .training_args import TrainingArguments
+from .utils import is_tf_available, logging, requires_backends
+
+
+logger = logging.get_logger(__name__)
+
+if is_tf_available():
+ import tensorflow as tf
+
+ from .modeling_tf_utils import keras
+
+
+@dataclass
+class TFTrainingArguments(TrainingArguments):
+ """
+ TrainingArguments is the subset of the arguments we use in our example scripts **which relate to the training loop
+ itself**.
+
+ Using [`HfArgumentParser`] we can turn this class into
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
+ command line.
+
+ Parameters:
+ output_dir (`str`):
+ The output directory where the model predictions and checkpoints will be written.
+ overwrite_output_dir (`bool`, *optional*, defaults to `False`):
+ If `True`, overwrite the content of the output directory. Use this to continue training if `output_dir`
+ points to a checkpoint directory.
+ do_train (`bool`, *optional*, defaults to `False`):
+ Whether to run training or not. This argument is not directly used by [`Trainer`], it's intended to be used
+ by your training/evaluation scripts instead. See the [example
+ scripts](https://github.com/huggingface/transformers/tree/main/examples) for more details.
+ do_eval (`bool`, *optional*):
+ Whether to run evaluation on the validation set or not. Will be set to `True` if `eval_strategy` is
+ different from `"no"`. This argument is not directly used by [`Trainer`], it's intended to be used by your
+ training/evaluation scripts instead. See the [example
+ scripts](https://github.com/huggingface/transformers/tree/main/examples) for more details.
+ do_predict (`bool`, *optional*, defaults to `False`):
+ Whether to run predictions on the test set or not. This argument is not directly used by [`Trainer`], it's
+ intended to be used by your training/evaluation scripts instead. See the [example
+ scripts](https://github.com/huggingface/transformers/tree/main/examples) for more details.
+ eval_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"no"`):
+ The evaluation strategy to adopt during training. Possible values are:
+
+ - `"no"`: No evaluation is done during training.
+ - `"steps"`: Evaluation is done (and logged) every `eval_steps`.
+ - `"epoch"`: Evaluation is done at the end of each epoch.
+
+ per_device_train_batch_size (`int`, *optional*, defaults to 8):
+ The batch size per GPU/TPU core/CPU for training.
+ per_device_eval_batch_size (`int`, *optional*, defaults to 8):
+ The batch size per GPU/TPU core/CPU for evaluation.
+ gradient_accumulation_steps (`int`, *optional*, defaults to 1):
+ Number of updates steps to accumulate the gradients for, before performing a backward/update pass.
+
+
+
+ When using gradient accumulation, one step is counted as one step with backward pass. Therefore, logging,
+ evaluation, save will be conducted every `gradient_accumulation_steps * xxx_step` training examples.
+
+
+
+ learning_rate (`float`, *optional*, defaults to 5e-5):
+ The initial learning rate for Adam.
+ weight_decay (`float`, *optional*, defaults to 0):
+ The weight decay to apply (if not zero).
+ adam_beta1 (`float`, *optional*, defaults to 0.9):
+ The beta1 hyperparameter for the Adam optimizer.
+ adam_beta2 (`float`, *optional*, defaults to 0.999):
+ The beta2 hyperparameter for the Adam optimizer.
+ adam_epsilon (`float`, *optional*, defaults to 1e-8):
+ The epsilon hyperparameter for the Adam optimizer.
+ max_grad_norm (`float`, *optional*, defaults to 1.0):
+ Maximum gradient norm (for gradient clipping).
+ num_train_epochs(`float`, *optional*, defaults to 3.0):
+ Total number of training epochs to perform.
+ max_steps (`int`, *optional*, defaults to -1):
+ If set to a positive number, the total number of training steps to perform. Overrides `num_train_epochs`.
+ For a finite dataset, training is reiterated through the dataset (if all data is exhausted) until
+ `max_steps` is reached.
+ warmup_ratio (`float`, *optional*, defaults to 0.0):
+ Ratio of total training steps used for a linear warmup from 0 to `learning_rate`.
+ warmup_steps (`int`, *optional*, defaults to 0):
+ Number of steps used for a linear warmup from 0 to `learning_rate`. Overrides any effect of `warmup_ratio`.
+ logging_dir (`str`, *optional*):
+ [TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to
+ *runs/**CURRENT_DATETIME_HOSTNAME***.
+ logging_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"steps"`):
+ The logging strategy to adopt during training. Possible values are:
+
+ - `"no"`: No logging is done during training.
+ - `"epoch"`: Logging is done at the end of each epoch.
+ - `"steps"`: Logging is done every `logging_steps`.
+
+ logging_first_step (`bool`, *optional*, defaults to `False`):
+ Whether to log and evaluate the first `global_step` or not.
+ logging_steps (`int`, *optional*, defaults to 500):
+ Number of update steps between two logs if `logging_strategy="steps"`.
+ save_strategy (`str` or [`~trainer_utils.SaveStrategy`], *optional*, defaults to `"steps"`):
+ The checkpoint save strategy to adopt during training. Possible values are:
+
+ - `"no"`: No save is done during training.
+ - `"epoch"`: Save is done at the end of each epoch.
+ - `"steps"`: Save is done every `save_steps`.
+
+ save_steps (`int`, *optional*, defaults to 500):
+ Number of updates steps before two checkpoint saves if `save_strategy="steps"`.
+ save_total_limit (`int`, *optional*):
+ If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in
+ `output_dir`.
+ no_cuda (`bool`, *optional*, defaults to `False`):
+ Whether to not use CUDA even when it is available or not.
+ seed (`int`, *optional*, defaults to 42):
+ Random seed that will be set at the beginning of training.
+ fp16 (`bool`, *optional*, defaults to `False`):
+ Whether to use 16-bit (mixed) precision training (through NVIDIA Apex) instead of 32-bit training.
+ fp16_opt_level (`str`, *optional*, defaults to 'O1'):
+ For `fp16` training, Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. See details on
+ the [Apex documentation](https://nvidia.github.io/apex/amp).
+ local_rank (`int`, *optional*, defaults to -1):
+ During distributed training, the rank of the process.
+ tpu_num_cores (`int`, *optional*):
+ When training on TPU, the number of TPU cores (automatically passed by launcher script).
+ debug (`bool`, *optional*, defaults to `False`):
+ Whether to activate the trace to record computation graphs and profiling information or not.
+ dataloader_drop_last (`bool`, *optional*, defaults to `False`):
+ Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch size)
+ or not.
+ eval_steps (`int`, *optional*, defaults to 1000):
+ Number of update steps before two evaluations.
+ past_index (`int`, *optional*, defaults to -1):
+ Some models like [TransformerXL](../model_doc/transformerxl) or :doc*XLNet <../model_doc/xlnet>* can make
+ use of the past hidden states for their predictions. If this argument is set to a positive int, the
+ `Trainer` will use the corresponding output (usually index 2) as the past state and feed it to the model at
+ the next training step under the keyword argument `mems`.
+ tpu_name (`str`, *optional*):
+ The name of the TPU the process is running on.
+ tpu_zone (`str`, *optional*):
+ The zone of the TPU the process is running on. If not specified, we will attempt to automatically detect
+ from metadata.
+ gcp_project (`str`, *optional*):
+ Google Cloud Project name for the Cloud TPU-enabled project. If not specified, we will attempt to
+ automatically detect from metadata.
+ run_name (`str`, *optional*):
+ A descriptor for the run. Notably used for trackio, wandb, mlflow, comet and swanlab logging.
+ xla (`bool`, *optional*):
+ Whether to activate the XLA compilation or not.
+ """
+
+ framework = "tf"
+ tpu_name: Optional[str] = field(
+ default=None,
+ metadata={"help": "Name of TPU"},
+ )
+
+ tpu_zone: Optional[str] = field(
+ default=None,
+ metadata={"help": "Zone of TPU"},
+ )
+
+ gcp_project: Optional[str] = field(
+ default=None,
+ metadata={"help": "Name of Cloud TPU-enabled project"},
+ )
+
+ poly_power: float = field(
+ default=1.0,
+ metadata={"help": "Power for the Polynomial decay LR scheduler."},
+ )
+
+ xla: bool = field(default=False, metadata={"help": "Whether to activate the XLA compilation or not"})
+
+ @cached_property
+ def _setup_strategy(self) -> tuple["tf.distribute.Strategy", int]:
+ requires_backends(self, ["tf"])
+ logger.info("Tensorflow: setting up strategy")
+
+ gpus = tf.config.list_physical_devices("GPU")
+
+ # Set to float16 at first
+ if self.fp16:
+ keras.mixed_precision.set_global_policy("mixed_float16")
+
+ if self.no_cuda:
+ strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
+ else:
+ try:
+ if self.tpu_name:
+ tpu = tf.distribute.cluster_resolver.TPUClusterResolver(
+ self.tpu_name, zone=self.tpu_zone, project=self.gcp_project
+ )
+ else:
+ tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
+ except ValueError:
+ if self.tpu_name:
+ raise RuntimeError(f"Couldn't connect to TPU {self.tpu_name}!")
+ else:
+ tpu = None
+
+ if tpu:
+ # Set to bfloat16 in case of TPU
+ if self.fp16:
+ keras.mixed_precision.set_global_policy("mixed_bfloat16")
+
+ tf.config.experimental_connect_to_cluster(tpu)
+ tf.tpu.experimental.initialize_tpu_system(tpu)
+
+ strategy = tf.distribute.TPUStrategy(tpu)
+
+ elif len(gpus) == 0:
+ strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
+ elif len(gpus) == 1:
+ strategy = tf.distribute.OneDeviceStrategy(device="/gpu:0")
+ elif len(gpus) > 1:
+ # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
+ strategy = tf.distribute.MirroredStrategy()
+ else:
+ raise ValueError("Cannot find the proper strategy, please check your environment properties.")
+
+ return strategy
+
+ @property
+ def strategy(self) -> "tf.distribute.Strategy":
+ """
+ The strategy used for distributed training.
+ """
+ requires_backends(self, ["tf"])
+ return self._setup_strategy
+
+ @property
+ def n_replicas(self) -> int:
+ """
+ The number of replicas (CPUs, GPUs or TPU cores) used in this training.
+ """
+ requires_backends(self, ["tf"])
+ return self._setup_strategy.num_replicas_in_sync
+
+ @property
+ def should_log(self):
+ """
+ Whether or not the current process should produce log.
+ """
+ return False # TF Logging is handled by Keras not the Trainer
+
+ @property
+ def train_batch_size(self) -> int:
+ """
+ The actual batch size for training (may differ from `per_gpu_train_batch_size` in distributed training).
+ """
+ if self.per_gpu_train_batch_size:
+ logger.warning(
+ "Using deprecated `--per_gpu_train_batch_size` argument which will be removed in a future "
+ "version. Using `--per_device_train_batch_size` is preferred."
+ )
+ per_device_batch_size = self.per_gpu_train_batch_size or self.per_device_train_batch_size
+ return per_device_batch_size * self.n_replicas
+
+ @property
+ def eval_batch_size(self) -> int:
+ """
+ The actual batch size for evaluation (may differ from `per_gpu_eval_batch_size` in distributed training).
+ """
+ if self.per_gpu_eval_batch_size:
+ logger.warning(
+ "Using deprecated `--per_gpu_eval_batch_size` argument which will be removed in a future "
+ "version. Using `--per_device_eval_batch_size` is preferred."
+ )
+ per_device_batch_size = self.per_gpu_eval_batch_size or self.per_device_eval_batch_size
+ return per_device_batch_size * self.n_replicas
+
+ @property
+ def n_gpu(self) -> int:
+ """
+ The number of replicas (CPUs, GPUs or TPU cores) used in this training.
+ """
+ requires_backends(self, ["tf"])
+ warnings.warn(
+ "The n_gpu argument is deprecated and will be removed in a future version, use n_replicas instead.",
+ FutureWarning,
+ )
+ return self._setup_strategy.num_replicas_in_sync
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/video_processing_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/video_processing_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0bc81bf8eb28f4d668077ac3574ebf14e6f3fefe
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/video_processing_utils.py
@@ -0,0 +1,895 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+import warnings
+from copy import deepcopy
+from functools import partial
+from typing import Any, Callable, Optional, Union
+
+import numpy as np
+
+from .dynamic_module_utils import custom_object_save
+from .image_processing_utils import (
+ BatchFeature,
+ get_size_dict,
+)
+from .image_processing_utils_fast import BaseImageProcessorFast
+from .image_utils import (
+ ChannelDimension,
+ SizeDict,
+ validate_kwargs,
+)
+from .processing_utils import Unpack, VideosKwargs
+from .utils import (
+ IMAGE_PROCESSOR_NAME,
+ PROCESSOR_NAME,
+ VIDEO_PROCESSOR_NAME,
+ TensorType,
+ add_start_docstrings,
+ copy_func,
+ download_url,
+ is_offline_mode,
+ is_remote_url,
+ is_torch_available,
+ is_torchcodec_available,
+ is_torchvision_v2_available,
+ logging,
+)
+from .utils.hub import cached_file
+from .utils.import_utils import requires
+from .video_utils import (
+ VideoInput,
+ VideoMetadata,
+ group_videos_by_shape,
+ is_valid_video,
+ load_video,
+ make_batched_metadata,
+ make_batched_videos,
+ reorder_videos,
+ to_channel_dimension_format,
+)
+
+
+if is_torch_available():
+ import torch
+
+if is_torchvision_v2_available():
+ from torchvision.transforms.v2 import functional as F
+
+
+logger = logging.get_logger(__name__)
+
+
+BASE_VIDEO_PROCESSOR_DOCSTRING = r"""
+ Args:
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the video's (height, width) dimensions to the specified `size`. Can be overridden by the
+ `do_resize` parameter in the `preprocess` method.
+ size (`dict`, *optional*, defaults to `self.size`):
+ Size of the output video after resizing. Can be overridden by the `size` parameter in the `preprocess`
+ method.
+ size_divisor (`int`, *optional*, defaults to `self.size_divisor`):
+ The size by which to make sure both the height and width can be divided.
+ default_to_square (`bool`, *optional*, defaults to `self.default_to_square`):
+ Whether to default to a square video when resizing, if size is an int.
+ resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the video. Only has an effect if `do_resize` is set to `True`. Can be
+ overridden by the `resample` parameter in the `preprocess` method.
+ do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
+ Whether to center crop the video to the specified `crop_size`. Can be overridden by `do_center_crop` in the
+ `preprocess` method.
+ crop_size (`dict[str, int]` *optional*, defaults to `self.crop_size`):
+ Size of the output video after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess`
+ method.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the video by the specified scale `rescale_factor`. Can be overridden by the
+ `do_rescale` parameter in the `preprocess` method.
+ rescale_factor (`int` or `float`, *optional*, defaults to `self.rescale_factor`):
+ Scale factor to use if rescaling the video. Only has an effect if `do_rescale` is set to `True`. Can be
+ overridden by the `rescale_factor` parameter in the `preprocess` method.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the video. Can be overridden by the `do_normalize` parameter in the `preprocess`
+ method. Can be overridden by the `do_normalize` parameter in the `preprocess` method.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
+ Mean to use if normalizing the video. This is a float or list of floats the length of the number of
+ channels in the video. Can be overridden by the `image_mean` parameter in the `preprocess` method. Can be
+ overridden by the `image_mean` parameter in the `preprocess` method.
+ image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
+ Standard deviation to use if normalizing the video. This is a float or list of floats the length of the
+ number of channels in the video. Can be overridden by the `image_std` parameter in the `preprocess` method.
+ Can be overridden by the `image_std` parameter in the `preprocess` method.
+ do_convert_rgb (`bool`, *optional*, defaults to `self.image_std`):
+ Whether to convert the video to RGB.
+ video_metadata (`VideoMetadata`, *optional*):
+ Metadata of the video containing information about total duration, fps and total number of frames.
+ do_sample_frames (`int`, *optional*, defaults to `self.do_sample_frames`):
+ Whether to sample frames from the video before processing or to process the whole video.
+ num_frames (`int`, *optional*, defaults to `self.num_frames`):
+ Maximum number of frames to sample when `do_sample_frames=True`.
+ fps (`int` or `float`, *optional*, defaults to `self.fps`):
+ Target frames to sample per second when `do_sample_frames=True`.
+ return_tensors (`str` or `TensorType`, *optional*):
+ Returns stacked tensors if set to `pt, otherwise returns a list of tensors.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output video. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: video in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: video in (height, width, num_channels) format.
+ - Unset: Use the channel dimension format of the input video.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input video. If unset, the channel dimension format is inferred
+ from the input video. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: video in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: video in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: video in (height, width) format.
+ device (`torch.device`, *optional*):
+ The device to process the videos on. If unset, the device is inferred from the input videos.
+ return_metadata (`bool`, *optional*):
+ Whether to return video metadata or not.
+ """
+
+
+@add_start_docstrings(
+ "Constructs a base VideoProcessor.",
+ BASE_VIDEO_PROCESSOR_DOCSTRING,
+)
+@requires(backends=("vision", "torchvision"))
+class BaseVideoProcessor(BaseImageProcessorFast):
+ _auto_class = None
+
+ resample = None
+ image_mean = None
+ image_std = None
+ size = None
+ size_divisor = None
+ default_to_square = True
+ crop_size = None
+ do_resize = None
+ do_center_crop = None
+ do_rescale = None
+ rescale_factor = 1 / 255
+ do_normalize = None
+ do_convert_rgb = None
+ do_sample_frames = None
+ fps = None
+ num_frames = None
+ video_metadata = None
+ return_metadata = False
+ valid_kwargs = VideosKwargs
+ model_input_names = ["pixel_values_videos"]
+
+ def __init__(self, **kwargs: Unpack[VideosKwargs]) -> None:
+ super().__init__()
+
+ self._processor_class = kwargs.pop("processor_class", None)
+
+ # Additional attributes without default values
+ for key, value in kwargs.items():
+ try:
+ setattr(self, key, value)
+ except AttributeError as err:
+ logger.error(f"Can't set {key} with value {value} for {self}")
+ raise err
+
+ # Prepare size related keys and turn then into `SizeDict`
+ size = kwargs.pop("size", self.size)
+ self.size = (
+ get_size_dict(size=size, default_to_square=kwargs.pop("default_to_square", self.default_to_square))
+ if size is not None
+ else None
+ )
+ crop_size = kwargs.pop("crop_size", self.crop_size)
+ self.crop_size = get_size_dict(crop_size, param_name="crop_size") if crop_size is not None else None
+
+ # Save valid kwargs in a list for further processing
+ self.model_valid_processing_keys = list(self.valid_kwargs.__annotations__.keys())
+ for key in self.model_valid_processing_keys:
+ if kwargs.get(key) is not None:
+ setattr(self, key, kwargs[key])
+ else:
+ setattr(self, key, deepcopy(getattr(self, key, None)))
+
+ def __call__(self, videos, **kwargs) -> BatchFeature:
+ return self.preprocess(videos, **kwargs)
+
+ def convert_to_rgb(
+ self,
+ video: "torch.Tensor",
+ ) -> VideoInput:
+ """
+ Converts a video to RGB format.
+
+ Args:
+ video (`"torch.Tensor"`):
+ The video to convert.
+
+ Returns:
+ `torch.Tensor`: The converted video.
+ """
+
+ video = F.grayscale_to_rgb(video)
+ if video.shape[-3] == 3 or not (video[..., 3, :, :] < 255).any():
+ return video
+
+ # There is a transparency layer, blend it with a white background.
+ # Calculate the alpha proportion for blending.
+ alpha = video[..., 3, :, :] / 255.0
+ video = (1 - alpha[..., None, :, :]) * 255 + alpha[..., None, :, :] * video[..., :3, :, :]
+ return video
+
+ def sample_frames(
+ self,
+ metadata: VideoMetadata,
+ num_frames: Optional[int] = None,
+ fps: Optional[Union[int, float]] = None,
+ **kwargs,
+ ):
+ """
+ Default sampling function which uniformly samples the desired number of frames between 0 and total number of frames.
+ If `fps` is passed along with metadata, `fps` frames per second are sampled uniformty. Arguments `num_frames`
+ and `fps` are mutually exclusive.
+
+ Args:
+ metadata (`VideoMetadata`):
+ Metadata of the video containing information about total duration, fps and total number of frames.
+ num_frames (`int`, *optional*):
+ Maximum number of frames to sample. Defaults to `self.num_frames`.
+ fps (`int` or `float`, *optional*):
+ Target frames to sample per second. Defaults to `self.fps`.
+
+ Returns:
+ np.ndarray:
+ Indices to sample video frames.
+ """
+ if fps is not None and num_frames is not None:
+ raise ValueError(
+ "`num_frames`, `fps`, and `sample_indices_fn` are mutually exclusive arguments, please use only one!"
+ )
+
+ num_frames = num_frames if num_frames is not None else self.num_frames
+ fps = fps if fps is not None else self.fps
+ total_num_frames = metadata.total_num_frames
+
+ # If num_frames is not given but fps is, calculate num_frames from fps
+ if num_frames is None and fps is not None:
+ if metadata is None or metadata.fps is None:
+ raise ValueError(
+ "Asked to sample `fps` frames per second but no video metadata was provided which is required when sampling with `fps`. "
+ "Please pass in `VideoMetadata` object or use a fixed `num_frames` per input video"
+ )
+ num_frames = int(total_num_frames / metadata.fps * fps)
+
+ if num_frames > total_num_frames:
+ raise ValueError(
+ f"Video can't be sampled. The `num_frames={num_frames}` exceeds `total_num_frames={total_num_frames}`. "
+ )
+
+ if num_frames is not None:
+ indices = torch.arange(0, total_num_frames, total_num_frames / num_frames).int()
+ else:
+ indices = torch.arange(0, total_num_frames).int()
+ return indices
+
+ def _decode_and_sample_videos(
+ self,
+ videos: VideoInput,
+ video_metadata: Union[VideoMetadata, dict],
+ do_sample_frames: Optional[bool] = None,
+ sample_indices_fn: Optional[Callable] = None,
+ ) -> list["torch.Tensor"]:
+ """
+ Decode input videos and sample frames if needed.
+ """
+ videos = make_batched_videos(videos)
+ video_metadata = make_batched_metadata(videos, video_metadata=video_metadata)
+
+ # Only sample frames if an array video is passed, otherwise first decode -> then sample
+ if is_valid_video(videos[0]) and do_sample_frames:
+ sampled_videos = []
+ sampled_metadata = []
+ for video, metadata in zip(videos, video_metadata):
+ indices = sample_indices_fn(metadata=metadata)
+ metadata.frames_indices = indices
+ sampled_videos.append(video[indices])
+ sampled_metadata.append(metadata)
+ videos = sampled_videos
+ video_metadata = sampled_metadata
+ elif not is_valid_video(videos[0]):
+ if isinstance(videos[0], list):
+ # Videos sometimes are passed as a list of image URLs, especially through templates
+ videos = [
+ torch.stack([F.pil_to_tensor(image) for image in images], dim=0)
+ for images in self.fetch_images(videos)
+ ]
+ if do_sample_frames:
+ raise ValueError(
+ "Sampling frames from a list of images is not supported! Set `do_sample_frames=False`."
+ )
+ else:
+ videos, video_metadata = self.fetch_videos(videos, sample_indices_fn=sample_indices_fn)
+
+ return videos, video_metadata
+
+ def _prepare_input_videos(
+ self,
+ videos: VideoInput,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ device: Optional[str] = None,
+ ) -> list["torch.Tensor"]:
+ """
+ Prepare the input videos for processing.
+ """
+ processed_videos = []
+ for video in videos:
+ # `make_batched_videos` always returns a 4D array per video
+ if isinstance(video, np.ndarray):
+ video = to_channel_dimension_format(video, ChannelDimension.FIRST, input_data_format)
+ # not using F.to_tensor as it doesn't handle (C, H, W) numpy arrays
+ video = torch.from_numpy(video).contiguous()
+
+ if device is not None:
+ video = video.to(device)
+
+ processed_videos.append(video)
+ return processed_videos
+
+ @add_start_docstrings(
+ BASE_VIDEO_PROCESSOR_DOCSTRING,
+ )
+ def preprocess(
+ self,
+ videos: VideoInput,
+ **kwargs: Unpack[VideosKwargs],
+ ) -> BatchFeature:
+ validate_kwargs(
+ captured_kwargs=kwargs.keys(),
+ valid_processor_keys=list(self.valid_kwargs.__annotations__.keys()) + ["return_tensors"],
+ )
+ # Set default kwargs from self. This ensures that if a kwarg is not provided
+ # by the user, it gets its default value from the instance, or is set to None.
+ for kwarg_name in self.valid_kwargs.__annotations__:
+ kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
+
+ input_data_format = kwargs.pop("input_data_format")
+ do_sample_frames = kwargs.pop("do_sample_frames")
+ device = kwargs.pop("device")
+ video_metadata = kwargs.pop("video_metadata")
+
+ sample_indices_fn = partial(self.sample_frames, **kwargs) if do_sample_frames else None
+ videos, video_metadata = self._decode_and_sample_videos(
+ videos,
+ video_metadata=video_metadata,
+ do_sample_frames=do_sample_frames,
+ sample_indices_fn=sample_indices_fn,
+ )
+ videos = self._prepare_input_videos(videos=videos, input_data_format=input_data_format, device=device)
+
+ kwargs = self._further_process_kwargs(**kwargs)
+ self._validate_preprocess_kwargs(**kwargs)
+
+ # Pop kwargs that are not needed in _preprocess
+ kwargs.pop("data_format")
+ return_metadata = kwargs.pop("return_metadata")
+
+ preprocessed_videos = self._preprocess(videos=videos, **kwargs)
+ if return_metadata:
+ preprocessed_videos["video_metadata"] = video_metadata
+ return preprocessed_videos
+
+ def _preprocess(
+ self,
+ videos: list["torch.Tensor"],
+ do_convert_rgb: bool,
+ do_resize: bool,
+ size: SizeDict,
+ interpolation: Optional["F.InterpolationMode"],
+ do_center_crop: bool,
+ crop_size: SizeDict,
+ do_rescale: bool,
+ rescale_factor: float,
+ do_normalize: bool,
+ image_mean: Optional[Union[float, list[float]]],
+ image_std: Optional[Union[float, list[float]]],
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ **kwargs,
+ ) -> BatchFeature:
+ # Group videos by size for batched resizing
+ grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
+ resized_videos_grouped = {}
+ for shape, stacked_videos in grouped_videos.items():
+ if do_convert_rgb:
+ stacked_videos = self.convert_to_rgb(stacked_videos)
+ if do_resize:
+ stacked_videos = self.resize(stacked_videos, size=size, interpolation=interpolation)
+ resized_videos_grouped[shape] = stacked_videos
+ resized_videos = reorder_videos(resized_videos_grouped, grouped_videos_index)
+
+ # Group videos by size for further processing
+ # Needed in case do_resize is False, or resize returns videos with different sizes
+ grouped_videos, grouped_videos_index = group_videos_by_shape(resized_videos)
+ processed_videos_grouped = {}
+ for shape, stacked_videos in grouped_videos.items():
+ if do_center_crop:
+ stacked_videos = self.center_crop(stacked_videos, crop_size)
+ # Fused rescale and normalize
+ stacked_videos = self.rescale_and_normalize(
+ stacked_videos, do_rescale, rescale_factor, do_normalize, image_mean, image_std
+ )
+ processed_videos_grouped[shape] = stacked_videos
+
+ processed_videos = reorder_videos(processed_videos_grouped, grouped_videos_index)
+ processed_videos = torch.stack(processed_videos, dim=0) if return_tensors else processed_videos
+
+ return BatchFeature(data={"pixel_values_videos": processed_videos}, tensor_type=return_tensors)
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ pretrained_model_name_or_path: Union[str, os.PathLike],
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
+ force_download: bool = False,
+ local_files_only: bool = False,
+ token: Optional[Union[str, bool]] = None,
+ revision: str = "main",
+ **kwargs,
+ ):
+ r"""
+ Instantiate a type of [`~video_processing_utils.VideoProcessorBase`] from an video processor.
+
+ Args:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ This can be either:
+
+ - a string, the *model id* of a pretrained video hosted inside a model repo on
+ huggingface.co.
+ - a path to a *directory* containing a video processor file saved using the
+ [`~video_processing_utils.VideoProcessorBase.save_pretrained`] method, e.g.,
+ `./my_model_directory/`.
+ - a path or url to a saved video processor JSON *file*, e.g.,
+ `./my_model_directory/video_preprocessor_config.json`.
+ cache_dir (`str` or `os.PathLike`, *optional*):
+ Path to a directory in which a downloaded pretrained model video processor should be cached if the
+ standard cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force to (re-)download the video processor files and override the cached versions if
+ they exist.
+ resume_download:
+ Deprecated and ignored. All downloads are now resumed by default when possible.
+ Will be removed in v5 of Transformers.
+ proxies (`dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
+ token (`str` or `bool`, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
+ the token generated when running `hf auth login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+
+
+
+
+ To test a pull request you made on the Hub, you can pass `revision="refs/pr/"`.
+
+
+
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
+ If `False`, then this function returns just the final video processor object. If `True`, then this
+ functions returns a `Tuple(video_processor, unused_kwargs)` where *unused_kwargs* is a dictionary
+ consisting of the key/value pairs whose keys are not video processor attributes: i.e., the part of
+ `kwargs` which has not been used to update `video_processor` and is otherwise ignored.
+ subfolder (`str`, *optional*, defaults to `""`):
+ In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
+ specify the folder name here.
+ kwargs (`dict[str, Any]`, *optional*):
+ The values in kwargs of any keys which are video processor attributes will be used to override the
+ loaded values. Behavior concerning key/value pairs whose keys are *not* video processor attributes is
+ controlled by the `return_unused_kwargs` keyword parameter.
+
+ Returns:
+ A video processor of type [`~video_processing_utils.ImagVideoProcessorBase`].
+
+ Examples:
+
+ ```python
+ # We can't instantiate directly the base class *VideoProcessorBase* so let's show the examples on a
+ # derived class: *LlavaOnevisionVideoProcessor*
+ video_processor = LlavaOnevisionVideoProcessor.from_pretrained(
+ "llava-hf/llava-onevision-qwen2-0.5b-ov-hf"
+ ) # Download video_processing_config from huggingface.co and cache.
+ video_processor = LlavaOnevisionVideoProcessor.from_pretrained(
+ "./test/saved_model/"
+ ) # E.g. video processor (or model) was saved using *save_pretrained('./test/saved_model/')*
+ video_processor = LlavaOnevisionVideoProcessor.from_pretrained("./test/saved_model/video_preprocessor_config.json")
+ video_processor = LlavaOnevisionVideoProcessor.from_pretrained(
+ "llava-hf/llava-onevision-qwen2-0.5b-ov-hf", do_normalize=False, foo=False
+ )
+ assert video_processor.do_normalize is False
+ video_processor, unused_kwargs = LlavaOnevisionVideoProcessor.from_pretrained(
+ "llava-hf/llava-onevision-qwen2-0.5b-ov-hf", do_normalize=False, foo=False, return_unused_kwargs=True
+ )
+ assert video_processor.do_normalize is False
+ assert unused_kwargs == {"foo": False}
+ ```"""
+ kwargs["cache_dir"] = cache_dir
+ kwargs["force_download"] = force_download
+ kwargs["local_files_only"] = local_files_only
+ kwargs["revision"] = revision
+
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ if use_auth_token is not None:
+ warnings.warn(
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+ FutureWarning,
+ )
+ if token is not None:
+ raise ValueError(
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+ )
+ token = use_auth_token
+
+ if token is not None:
+ kwargs["token"] = token
+
+ video_processor_dict, kwargs = cls.get_video_processor_dict(pretrained_model_name_or_path, **kwargs)
+
+ return cls.from_dict(video_processor_dict, **kwargs)
+
+ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
+ """
+ Save an video processor object to the directory `save_directory`, so that it can be re-loaded using the
+ [`~video_processing_utils.VideoProcessorBase.from_pretrained`] class method.
+
+ Args:
+ save_directory (`str` or `os.PathLike`):
+ Directory where the video processor JSON file will be saved (will be created if it does not exist).
+ push_to_hub (`bool`, *optional*, defaults to `False`):
+ Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
+ repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
+ namespace).
+ kwargs (`dict[str, Any]`, *optional*):
+ Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
+ """
+ use_auth_token = kwargs.pop("use_auth_token", None)
+
+ if use_auth_token is not None:
+ warnings.warn(
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+ FutureWarning,
+ )
+ if kwargs.get("token") is not None:
+ raise ValueError(
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+ )
+ kwargs["token"] = use_auth_token
+
+ if os.path.isfile(save_directory):
+ raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
+
+ os.makedirs(save_directory, exist_ok=True)
+
+ if push_to_hub:
+ commit_message = kwargs.pop("commit_message", None)
+ repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
+ repo_id = self._create_repo(repo_id, **kwargs)
+ files_timestamps = self._get_files_timestamps(save_directory)
+
+ # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
+ # loaded from the Hub.
+ if self._auto_class is not None:
+ custom_object_save(self, save_directory, config=self)
+
+ # If we save using the predefined names, we can load using `from_pretrained`
+ output_video_processor_file = os.path.join(save_directory, VIDEO_PROCESSOR_NAME)
+
+ self.to_json_file(output_video_processor_file)
+ logger.info(f"Video processor saved in {output_video_processor_file}")
+
+ if push_to_hub:
+ self._upload_modified_files(
+ save_directory,
+ repo_id,
+ files_timestamps,
+ commit_message=commit_message,
+ token=kwargs.get("token"),
+ )
+
+ return [output_video_processor_file]
+
+ @classmethod
+ def get_video_processor_dict(
+ cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
+ ) -> tuple[dict[str, Any], dict[str, Any]]:
+ """
+ From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a
+ video processor of type [`~video_processing_utils.VideoProcessorBase`] using `from_dict`.
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
+ subfolder (`str`, *optional*, defaults to `""`):
+ In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
+ specify the folder name here.
+
+ Returns:
+ `tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the video processor object.
+ """
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ resume_download = kwargs.pop("resume_download", None)
+ proxies = kwargs.pop("proxies", None)
+ token = kwargs.pop("token", None)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ local_files_only = kwargs.pop("local_files_only", False)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", "")
+
+ from_pipeline = kwargs.pop("_from_pipeline", None)
+ from_auto_class = kwargs.pop("_from_auto", False)
+
+ if use_auth_token is not None:
+ warnings.warn(
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+ FutureWarning,
+ )
+ if token is not None:
+ raise ValueError(
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+ )
+ token = use_auth_token
+
+ user_agent = {"file_type": "video processor", "from_auto_class": from_auto_class}
+ if from_pipeline is not None:
+ user_agent["using_pipeline"] = from_pipeline
+
+ if is_offline_mode() and not local_files_only:
+ logger.info("Offline mode: forcing local_files_only=True")
+ local_files_only = True
+
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+ is_local = os.path.isdir(pretrained_model_name_or_path)
+ if os.path.isfile(pretrained_model_name_or_path):
+ resolved_video_processor_file = pretrained_model_name_or_path
+ is_local = True
+ elif is_remote_url(pretrained_model_name_or_path):
+ video_processor_file = pretrained_model_name_or_path
+ resolved_video_processor_file = download_url(pretrained_model_name_or_path)
+ else:
+ video_processor_file = VIDEO_PROCESSOR_NAME
+ try:
+ # Try to load with a new config name first and if not successful try with the old file name
+ # NOTE: we will gradually change to saving all processor configs as nested dict in PROCESSOR_NAME
+ resolved_video_processor_files = [
+ resolved_file
+ for filename in [VIDEO_PROCESSOR_NAME, IMAGE_PROCESSOR_NAME, PROCESSOR_NAME]
+ if (
+ resolved_file := cached_file(
+ pretrained_model_name_or_path,
+ filename=filename,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ token=token,
+ user_agent=user_agent,
+ revision=revision,
+ subfolder=subfolder,
+ _raise_exceptions_for_missing_entries=False,
+ )
+ )
+ is not None
+ ]
+ resolved_video_processor_file = resolved_video_processor_files[0]
+ except OSError:
+ # Raise any OS error raise by `cached_file`. It will have a helpful error message adapted to
+ # the original exception.
+ raise
+ except Exception:
+ # For any other exception, we throw a generic error.
+ raise OSError(
+ f"Can't load video processor for '{pretrained_model_name_or_path}'. If you were trying to load"
+ " it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
+ f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
+ f" directory containing a {VIDEO_PROCESSOR_NAME} file"
+ )
+
+ try:
+ # Load video_processor dict
+ with open(resolved_video_processor_file, "r", encoding="utf-8") as reader:
+ text = reader.read()
+ video_processor_dict = json.loads(text)
+ video_processor_dict = video_processor_dict.get("video_processor", video_processor_dict)
+
+ except json.JSONDecodeError:
+ raise OSError(
+ f"It looks like the config file at '{resolved_video_processor_file}' is not a valid JSON file."
+ )
+
+ if is_local:
+ logger.info(f"loading configuration file {resolved_video_processor_file}")
+ else:
+ logger.info(
+ f"loading configuration file {video_processor_file} from cache at {resolved_video_processor_file}"
+ )
+ return video_processor_dict, kwargs
+
+ @classmethod
+ def from_dict(cls, video_processor_dict: dict[str, Any], **kwargs):
+ """
+ Instantiates a type of [`~video_processing_utils.VideoProcessorBase`] from a Python dictionary of parameters.
+
+ Args:
+ video_processor_dict (`dict[str, Any]`):
+ Dictionary that will be used to instantiate the video processor object. Such a dictionary can be
+ retrieved from a pretrained checkpoint by leveraging the
+ [`~video_processing_utils.VideoProcessorBase.to_dict`] method.
+ kwargs (`dict[str, Any]`):
+ Additional parameters from which to initialize the video processor object.
+
+ Returns:
+ [`~video_processing_utils.VideoProcessorBase`]: The video processor object instantiated from those
+ parameters.
+ """
+ video_processor_dict = video_processor_dict.copy()
+ return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
+
+ # The `size` parameter is a dict and was previously an int or tuple in feature extractors.
+ # We set `size` here directly to the `video_processor_dict` so that it is converted to the appropriate
+ # dict within the video processor and isn't overwritten if `size` is passed in as a kwarg.
+ if "size" in kwargs and "size" in video_processor_dict:
+ video_processor_dict["size"] = kwargs.pop("size")
+ if "crop_size" in kwargs and "crop_size" in video_processor_dict:
+ video_processor_dict["crop_size"] = kwargs.pop("crop_size")
+
+ video_processor = cls(**video_processor_dict)
+
+ # Update video_processor with kwargs if needed
+ to_remove = []
+ for key, value in kwargs.items():
+ if hasattr(video_processor, key):
+ setattr(video_processor, key, value)
+ to_remove.append(key)
+ for key in to_remove:
+ kwargs.pop(key, None)
+
+ logger.info(f"Video processor {video_processor}")
+ if return_unused_kwargs:
+ return video_processor, kwargs
+ else:
+ return video_processor
+
+ def to_dict(self) -> dict[str, Any]:
+ """
+ Serializes this instance to a Python dictionary.
+
+ Returns:
+ `dict[str, Any]`: Dictionary of all the attributes that make up this video processor instance.
+ """
+ output = deepcopy(self.__dict__)
+ output.pop("model_valid_processing_keys", None)
+ output.pop("_valid_kwargs_names", None)
+ output["video_processor_type"] = self.__class__.__name__
+
+ return output
+
+ def to_json_string(self) -> str:
+ """
+ Serializes this instance to a JSON string.
+
+ Returns:
+ `str`: String containing all the attributes that make up this feature_extractor instance in JSON format.
+ """
+ dictionary = self.to_dict()
+
+ for key, value in dictionary.items():
+ if isinstance(value, np.ndarray):
+ dictionary[key] = value.tolist()
+
+ # make sure private name "_processor_class" is correctly
+ # saved as "processor_class"
+ _processor_class = dictionary.pop("_processor_class", None)
+ if _processor_class is not None:
+ dictionary["processor_class"] = _processor_class
+
+ return json.dumps(dictionary, indent=2, sort_keys=True) + "\n"
+
+ def to_json_file(self, json_file_path: Union[str, os.PathLike]):
+ """
+ Save this instance to a JSON file.
+
+ Args:
+ json_file_path (`str` or `os.PathLike`):
+ Path to the JSON file in which this image_processor instance's parameters will be saved.
+ """
+ with open(json_file_path, "w", encoding="utf-8") as writer:
+ writer.write(self.to_json_string())
+
+ def __repr__(self):
+ return f"{self.__class__.__name__} {self.to_json_string()}"
+
+ @classmethod
+ def from_json_file(cls, json_file: Union[str, os.PathLike]):
+ """
+ Instantiates a video processor of type [`~video_processing_utils.VideoProcessorBase`] from the path to a JSON
+ file of parameters.
+
+ Args:
+ json_file (`str` or `os.PathLike`):
+ Path to the JSON file containing the parameters.
+
+ Returns:
+ A video processor of type [`~video_processing_utils.VideoProcessorBase`]: The video_processor object
+ instantiated from that JSON file.
+ """
+ with open(json_file, "r", encoding="utf-8") as reader:
+ text = reader.read()
+ video_processor_dict = json.loads(text)
+ return cls(**video_processor_dict)
+
+ @classmethod
+ def register_for_auto_class(cls, auto_class="AutoVideoProcessor"):
+ """
+ Register this class with a given auto class. This should only be used for custom video processors as the ones
+ in the library are already mapped with `AutoVideoProcessor `.
+
+
+
+ This API is experimental and may have some slight breaking changes in the next releases.
+
+
+
+ Args:
+ auto_class (`str` or `type`, *optional*, defaults to `"AutoVideoProcessor "`):
+ The auto class to register this new video processor with.
+ """
+ if not isinstance(auto_class, str):
+ auto_class = auto_class.__name__
+
+ import transformers.models.auto as auto_module
+
+ if not hasattr(auto_module, auto_class):
+ raise ValueError(f"{auto_class} is not a valid auto class.")
+
+ cls._auto_class = auto_class
+
+ def fetch_videos(self, video_url_or_urls: Union[str, list[str], list[list[str]]], sample_indices_fn=None):
+ """
+ Convert a single or a list of urls into the corresponding `np.array` objects.
+
+ If a single url is passed, the return value will be a single object. If a list is passed a list of objects is
+ returned.
+ """
+ backend = "torchcodec"
+ if not is_torchcodec_available():
+ warnings.warn(
+ "`torchcodec` is not installed and cannot be used to decode the video by default. "
+ "Falling back to `torchvision`. Note that `torchvision` decoding is deprecated and will be removed in future versions. "
+ )
+ backend = "torchvision"
+
+ if isinstance(video_url_or_urls, list):
+ return list(zip(*[self.fetch_videos(x, sample_indices_fn=sample_indices_fn) for x in video_url_or_urls]))
+ else:
+ return load_video(video_url_or_urls, backend=backend, sample_indices_fn=sample_indices_fn)
+
+
+BaseVideoProcessor.push_to_hub = copy_func(BaseVideoProcessor.push_to_hub)
+if BaseVideoProcessor.push_to_hub.__doc__ is not None:
+ BaseVideoProcessor.push_to_hub.__doc__ = BaseVideoProcessor.push_to_hub.__doc__.format(
+ object="video processor", object_class="AutoVideoProcessor", object_files="video processor file"
+ )
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/video_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/video_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ed5720a8e410e061347dd4946fa59f5f9594bb1
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/video_utils.py
@@ -0,0 +1,878 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import warnings
+from collections.abc import Iterable, Mapping
+from contextlib import redirect_stdout
+from dataclasses import dataclass, fields
+from io import BytesIO
+from typing import Callable, NewType, Optional, Union
+from urllib.parse import urlparse
+
+import numpy as np
+import requests
+
+from .image_transforms import PaddingMode, to_channel_dimension_format
+from .image_utils import ChannelDimension, infer_channel_dimension_format, is_valid_image
+from .utils import (
+ is_av_available,
+ is_cv2_available,
+ is_decord_available,
+ is_numpy_array,
+ is_torch_available,
+ is_torch_tensor,
+ is_torchcodec_available,
+ is_torchvision_available,
+ is_vision_available,
+ is_yt_dlp_available,
+ logging,
+ requires_backends,
+)
+
+
+if is_vision_available():
+ import PIL.Image
+ import PIL.ImageOps
+
+ if is_torchvision_available():
+ from torchvision import io as torchvision_io
+
+if is_torch_available():
+ import torch
+
+
+logger = logging.get_logger(__name__)
+
+URL = NewType("URL", str)
+Path = NewType("Path", str)
+
+VideoInput = Union[
+ list["PIL.Image.Image"],
+ np.ndarray,
+ "torch.Tensor",
+ list[np.ndarray],
+ list["torch.Tensor"],
+ list[list["PIL.Image.Image"]],
+ list[list[np.ndarray]],
+ list[list["torch.Tensor"]],
+ URL,
+ list[URL],
+ list[list[URL]],
+ Path,
+ list[Path],
+ list[list[Path]],
+]
+
+
+@dataclass
+class VideoMetadata(Mapping):
+ total_num_frames: int
+ fps: Optional[float] = None
+ width: Optional[int] = None
+ height: Optional[int] = None
+ duration: Optional[float] = None
+ video_backend: Optional[str] = None
+ frames_indices: Optional[list[int]] = None
+
+ def __iter__(self):
+ return (f.name for f in fields(self))
+
+ def __len__(self):
+ return len(fields(self))
+
+ def __getitem__(self, item):
+ return getattr(self, item)
+
+ def __setitem__(self, key, value):
+ return setattr(self, key, value)
+
+ @property
+ def timestamps(self) -> list[float]:
+ "Timestamps of the sampled frames in seconds."
+ if self.fps is None or self.frames_indices is None:
+ raise ValueError("Cannot infer video `timestamps` when `fps` or `frames_indices` is None.")
+ return [frame_idx / self.fps for frame_idx in self.frames_indices]
+
+ def update(self, dictionary):
+ for key, value in dictionary.items():
+ if hasattr(self, key):
+ setattr(self, key, value)
+
+
+def is_valid_video_frame(frame):
+ return isinstance(frame, PIL.Image.Image) or (
+ (is_numpy_array(frame) or is_torch_tensor(frame)) and frame.ndim == 3
+ )
+
+
+def is_valid_video(video):
+ if not isinstance(video, (list, tuple)):
+ return (is_numpy_array(video) or is_torch_tensor(video)) and video.ndim == 4
+ return video and all(is_valid_video_frame(frame) for frame in video)
+
+
+def valid_videos(videos):
+ # If we have a list of videos, it could be either one video as list of frames or a batch
+ if isinstance(videos, (list, tuple)):
+ for video_or_frame in videos:
+ if not (is_valid_video(video_or_frame) or is_valid_video_frame(video_or_frame)):
+ return False
+ # If not a list, then we have a single 4D video or 5D batched tensor
+ elif not is_valid_video(videos) or videos.ndim == 5:
+ return False
+ return True
+
+
+def is_batched_video(videos):
+ if isinstance(videos, (list, tuple)):
+ return is_valid_video(videos[0])
+ elif (is_numpy_array(videos) or is_torch_tensor(videos)) and videos.ndim == 5:
+ return True
+ return False
+
+
+def is_scaled_video(video: np.ndarray) -> bool:
+ """
+ Checks to see whether the pixel values have already been rescaled to [0, 1].
+ """
+ # It's possible the video has pixel values in [0, 255] but is of floating type
+ return np.min(video) >= 0 and np.max(video) <= 1
+
+
+def convert_pil_frames_to_video(videos: list[VideoInput]) -> list[Union[np.ndarray, "torch.Tensor"]]:
+ """
+ Given a batch of videos, converts each video to a 4D array. If video is already in array type,
+ it is simply returned. We assume that all inputs in the list are in the same format, based on the type of the first element.
+
+ Args:
+ videos (`VideoInput`):
+ Video inputs to turn into a list of videos.
+ """
+
+ if not (isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0])):
+ return videos
+
+ video_converted = []
+ for video in videos:
+ video = [np.array(frame) for frame in video]
+ video = np.stack(video)
+ video_converted.append(video)
+ return video_converted
+
+
+def make_batched_videos(videos) -> list[Union[np.ndarray, "torch.Tensor", "URL", "Path"]]:
+ """
+ Ensure that the input is a list of videos. If the input is a single video, it is converted to a list of length 1.
+ If the input is a batch of videos, it is converted to a list of 4D video arrays. Videos passed as list `PIL.Image`
+ frames are converted to 4D arrays.
+
+ We assume that all inputs in the list are in the same format, based on the type of the first element.
+
+ Args:
+ videos (`VideoInput`):
+ Video inputs to turn into a list of videos.
+ """
+ # Early exit for deeply nested list of image frame paths. We shouldn't flatten them
+ try:
+ if isinstance(videos[0][0], list) and isinstance(videos[0][0][0], str):
+ return [image_paths for sublist in videos for image_paths in sublist]
+ except (IndexError, TypeError):
+ pass
+
+ if isinstance(videos, str) or is_valid_video(videos):
+ return convert_pil_frames_to_video([videos])
+ # only one frame passed, thus we unsqueeze time dim
+ elif is_valid_image(videos):
+ if isinstance(videos, PIL.Image.Image):
+ videos = np.array(videos)
+ return [videos[None, ...]]
+ elif not isinstance(videos, list):
+ raise ValueError(
+ f"Invalid video input. Expected either a list of video frames or an input of 4 or 5 dimensions, but got"
+ f" type {type(videos)}."
+ )
+
+ # Recursively flatten any nested structure
+ flat_videos_list = []
+ for item in videos:
+ if isinstance(item, str) or is_valid_video(item):
+ flat_videos_list.append(item)
+ elif isinstance(item, list) and item:
+ flat_videos_list.extend(make_batched_videos(item))
+
+ flat_videos_list = convert_pil_frames_to_video(flat_videos_list)
+ return flat_videos_list
+
+
+def make_batched_metadata(videos: VideoInput, video_metadata: Union[VideoMetadata, dict]):
+ if video_metadata is None:
+ # Create default metadata and fill attributes we can infer from given video
+ video_metadata = [
+ {
+ "total_num_frames": len(video),
+ "fps": None,
+ "duration": None,
+ "frames_indices": list(range(len(video))),
+ "height": get_video_size(video)[0] if is_valid_video(video) else None,
+ "width": get_video_size(video)[1] if is_valid_video(video) else None,
+ }
+ for video in videos
+ ]
+
+ if isinstance(video_metadata, list):
+ # Flatten if nested list
+ if isinstance(video_metadata[0], list):
+ video_metadata = [
+ VideoMetadata(**metadata) for metadata_list in video_metadata for metadata in metadata_list
+ ]
+ # Simply wrap in VideoMetadata if simple dict
+ elif isinstance(video_metadata[0], dict):
+ video_metadata = [VideoMetadata(**metadata) for metadata in video_metadata]
+ else:
+ # Create a batched list from single object
+ video_metadata = [VideoMetadata(**video_metadata)]
+ return video_metadata
+
+
+def get_video_size(video: np.ndarray, channel_dim: Optional[ChannelDimension] = None) -> tuple[int, int]:
+ """
+ Returns the (height, width) dimensions of the video.
+
+ Args:
+ video (`np.ndarray`):
+ The video to get the dimensions of.
+ channel_dim (`ChannelDimension`, *optional*):
+ Which dimension the channel dimension is in. If `None`, will infer the channel dimension from the video.
+
+ Returns:
+ A tuple of the video's height and width.
+ """
+ if channel_dim is None:
+ channel_dim = infer_channel_dimension_format(video, num_channels=(1, 3, 4))
+
+ if channel_dim == ChannelDimension.FIRST:
+ return video.shape[-2], video.shape[-1]
+ elif channel_dim == ChannelDimension.LAST:
+ return video.shape[-3], video.shape[-2]
+ else:
+ raise ValueError(f"Unsupported data format: {channel_dim}")
+
+
+def get_uniform_frame_indices(total_num_frames: int, num_frames: Optional[int] = None):
+ """
+ Creates a numpy array for uniform sampling of `num_frame` frames from `total_num_frames`
+ when loading a video.
+
+ Args:
+ total_num_frames (`int`):
+ Total number of frames that a video has.
+ num_frames (`int`, *optional*):
+ Number of frames to sample uniformly. If not specified, all frames are sampled.
+
+ Returns:
+ np.ndarray: np array of frame indices that will be sampled.
+ """
+ if num_frames is not None:
+ indices = np.arange(0, total_num_frames, total_num_frames / num_frames).astype(int)
+ else:
+ indices = np.arange(0, total_num_frames).astype(int)
+ return indices
+
+
+def default_sample_indices_fn(metadata: VideoMetadata, num_frames=None, fps=None, **kwargs):
+ """
+ A default sampling function that replicates the logic used in get_uniform_frame_indices,
+ while optionally handling `fps` if `num_frames` is not provided.
+
+ Args:
+ metadata (`VideoMetadata`):
+ `VideoMetadata` object containing metadata about the video, such as "total_num_frames" or "fps".
+ num_frames (`int`, *optional*):
+ Number of frames to sample uniformly.
+ fps (`int` or `float`, *optional*):
+ Desired frames per second. Takes priority over num_frames if both are provided.
+
+ Returns:
+ `np.ndarray`: Array of frame indices to sample.
+ """
+ total_num_frames = metadata.total_num_frames
+ video_fps = metadata.fps
+
+ # If num_frames is not given but fps is, calculate num_frames from fps
+ if num_frames is None and fps is not None:
+ num_frames = int(total_num_frames / video_fps * fps)
+ if num_frames > total_num_frames:
+ raise ValueError(
+ f"When loading the video with fps={fps}, we computed num_frames={num_frames} "
+ f"which exceeds total_num_frames={total_num_frames}. Check fps or video metadata."
+ )
+
+ if num_frames is not None:
+ indices = np.arange(0, total_num_frames, total_num_frames / num_frames, dtype=int)
+ else:
+ indices = np.arange(0, total_num_frames, dtype=int)
+ return indices
+
+
+def read_video_opencv(
+ video_path: Union["URL", "Path"],
+ sample_indices_fn: Callable,
+ **kwargs,
+) -> tuple[np.ndarray, VideoMetadata]:
+ """
+ Decode a video using the OpenCV backend.
+
+ Args:
+ video_path (`str`):
+ Path to the video file.
+ sample_indices_fn (`Callable`):
+ A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
+ by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
+ If not provided, simple uniform sampling with fps is performed.
+ Example:
+ def sample_indices_fn(metadata, **kwargs):
+ return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
+
+ Returns:
+ tuple[`np.ndarray`, `VideoMetadata`]: A tuple containing:
+ - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
+ - `VideoMetadata` object.
+ """
+ # Lazy import cv2
+ requires_backends(read_video_opencv, ["cv2"])
+ import cv2
+
+ video = cv2.VideoCapture(video_path)
+ total_num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
+ video_fps = video.get(cv2.CAP_PROP_FPS)
+ duration = total_num_frames / video_fps if video_fps else 0
+ metadata = VideoMetadata(
+ total_num_frames=int(total_num_frames),
+ fps=float(video_fps),
+ duration=float(duration),
+ video_backend="opencv",
+ height=int(video.get(cv2.CAP_PROP_FRAME_HEIGHT)),
+ width=int(video.get(cv2.CAP_PROP_FRAME_WIDTH)),
+ )
+ indices = sample_indices_fn(metadata=metadata, **kwargs)
+
+ index = 0
+ frames = []
+ while video.isOpened():
+ success, frame = video.read()
+ if not success:
+ break
+ if index in indices:
+ height, width, channel = frame.shape
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ frames.append(frame[0:height, 0:width, 0:channel])
+ if success:
+ index += 1
+ if index >= total_num_frames:
+ break
+
+ video.release()
+ metadata.frames_indices = indices
+ return np.stack(frames), metadata
+
+
+def read_video_decord(
+ video_path: Union["URL", "Path"],
+ sample_indices_fn: Callable,
+ **kwargs,
+):
+ """
+ Decode a video using the Decord backend.
+
+ Args:
+ video_path (`str`):
+ Path to the video file.
+ sample_indices_fn (`Callable`):
+ A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
+ by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
+ If not provided, simple uniform sampling with fps is performed.
+ Example:
+ def sample_indices_fn(metadata, **kwargs):
+ return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
+
+ Returns:
+ tuple[`np.array`, `VideoMetadata`]: A tuple containing:
+ - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
+ - `VideoMetadata` object.
+ """
+ # Lazy import from decord
+ requires_backends(read_video_decord, ["decord"])
+ from decord import VideoReader, cpu
+
+ vr = VideoReader(uri=video_path, ctx=cpu(0)) # decord has problems with gpu
+ video_fps = vr.get_avg_fps()
+ total_num_frames = len(vr)
+ duration = total_num_frames / video_fps if video_fps else 0
+ metadata = VideoMetadata(
+ total_num_frames=int(total_num_frames),
+ fps=float(video_fps),
+ duration=float(duration),
+ video_backend="decord",
+ )
+
+ indices = sample_indices_fn(metadata=metadata, **kwargs)
+ video = vr.get_batch(indices).asnumpy()
+
+ metadata.update(
+ {
+ "frames_indices": indices,
+ "height": video.shape[1],
+ "width": video.shape[2],
+ }
+ )
+ return video, metadata
+
+
+def read_video_pyav(
+ video_path: Union["URL", "Path"],
+ sample_indices_fn: Callable,
+ **kwargs,
+):
+ """
+ Decode the video with PyAV decoder.
+
+ Args:
+ video_path (`str`):
+ Path to the video file.
+ sample_indices_fn (`Callable`, *optional*):
+ A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
+ by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
+ If not provided, simple uniform sampling with fps is performed.
+ Example:
+ def sample_indices_fn(metadata, **kwargs):
+ return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
+
+ Returns:
+ tuple[`np.array`, `VideoMetadata`]: A tuple containing:
+ - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
+ - `VideoMetadata` object.
+ """
+ # Lazy import av
+ requires_backends(read_video_pyav, ["av"])
+ import av
+
+ container = av.open(video_path)
+ total_num_frames = container.streams.video[0].frames
+ video_fps = container.streams.video[0].average_rate # should we better use `av_guess_frame_rate`?
+ duration = total_num_frames / video_fps if video_fps else 0
+ metadata = VideoMetadata(
+ total_num_frames=int(total_num_frames),
+ fps=float(video_fps),
+ duration=float(duration),
+ video_backend="pyav",
+ height=container.streams.video[0].height,
+ width=container.streams.video[0].width,
+ )
+ indices = sample_indices_fn(metadata=metadata, **kwargs)
+
+ frames = []
+ container.seek(0)
+ end_index = indices[-1]
+ for i, frame in enumerate(container.decode(video=0)):
+ if i > end_index:
+ break
+ if i >= 0 and i in indices:
+ frames.append(frame)
+
+ video = np.stack([x.to_ndarray(format="rgb24") for x in frames])
+ metadata.frames_indices = indices
+ return video, metadata
+
+
+def read_video_torchvision(
+ video_path: Union["URL", "Path"],
+ sample_indices_fn: Callable,
+ **kwargs,
+):
+ """
+ Decode the video with torchvision decoder.
+
+ Args:
+ video_path (`str`):
+ Path to the video file.
+ sample_indices_fn (`Callable`, *optional*):
+ A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
+ by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
+ If not provided, simple uniform sampling with fps is performed.
+ Example:
+ def sample_indices_fn(metadata, **kwargs):
+ return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
+
+ Returns:
+ tuple[`torch.Tensor`, `VideoMetadata`]: A tuple containing:
+ - Torch tensor of frames in RGB (shape: [num_frames, height, width, 3]).
+ - `VideoMetadata` object.
+ """
+ warnings.warn(
+ "Using `torchvision` for video decoding is deprecated and will be removed in future versions. "
+ "Please use `torchcodec` instead."
+ )
+ video, _, info = torchvision_io.read_video(
+ video_path,
+ start_pts=0.0,
+ end_pts=None,
+ pts_unit="sec",
+ output_format="TCHW",
+ )
+ video_fps = info["video_fps"]
+ total_num_frames = video.size(0)
+ duration = total_num_frames / video_fps if video_fps else 0
+ metadata = VideoMetadata(
+ total_num_frames=int(total_num_frames),
+ fps=float(video_fps),
+ duration=float(duration),
+ video_backend="torchvision",
+ )
+
+ indices = sample_indices_fn(metadata=metadata, **kwargs)
+
+ video = video[indices].contiguous()
+ metadata.update(
+ {
+ "frames_indices": indices,
+ "height": video.shape[2],
+ "width": video.shape[3],
+ }
+ )
+ return video, metadata
+
+
+def read_video_torchcodec(
+ video_path: Union["URL", "Path"],
+ sample_indices_fn: Callable,
+ **kwargs,
+):
+ """
+ Decode the video with torchcodec decoder.
+
+ Args:
+ video_path (`str`):
+ Path to the video file.
+ sample_indices_fn (`Callable`):
+ A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
+ by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
+ If not provided, simple uniform sampling with fps is performed.
+ Example:
+ def sample_indices_fn(metadata, **kwargs):
+ return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
+
+ Returns:
+ Tuple[`torch.Tensor`, `VideoMetadata`]: A tuple containing:
+ - Torch tensor of frames in RGB (shape: [num_frames, height, width, 3]).
+ - `VideoMetadata` object.
+ """
+ # Lazy import torchcodec
+ requires_backends(read_video_torchcodec, ["torchcodec"])
+ from torchcodec.decoders import VideoDecoder
+
+ decoder = VideoDecoder(
+ video_path,
+ # Interestingly `exact` mode takes less than approximate when we load the whole video
+ seek_mode="exact",
+ # Allow FFmpeg decide on the number of threads for efficiency
+ num_ffmpeg_threads=0,
+ device=kwargs.get("device"),
+ )
+ metadata = VideoMetadata(
+ total_num_frames=decoder.metadata.num_frames,
+ fps=decoder.metadata.average_fps,
+ duration=decoder.metadata.duration_seconds,
+ video_backend="torchcodec",
+ height=decoder.metadata.height,
+ width=decoder.metadata.width,
+ )
+ indices = sample_indices_fn(metadata=metadata, **kwargs)
+
+ video = decoder.get_frames_at(indices=indices).data.contiguous()
+ metadata.frames_indices = indices
+ return video, metadata
+
+
+VIDEO_DECODERS = {
+ "decord": read_video_decord,
+ "opencv": read_video_opencv,
+ "pyav": read_video_pyav,
+ "torchvision": read_video_torchvision,
+ "torchcodec": read_video_torchcodec,
+}
+
+
+def load_video(
+ video: VideoInput,
+ num_frames: Optional[int] = None,
+ fps: Optional[Union[int, float]] = None,
+ backend: str = "pyav",
+ sample_indices_fn: Optional[Callable] = None,
+ **kwargs,
+) -> np.ndarray:
+ """
+ Loads `video` to a numpy array.
+
+ Args:
+ video (`VideoInput`):
+ The video to convert to the numpy array format. Can be a link to video or local path.
+ num_frames (`int`, *optional*):
+ Number of frames to sample uniformly. If not passed, the whole video is loaded.
+ fps (`int` or `float`, *optional*):
+ Number of frames to sample per second. Should be passed only when `num_frames=None`.
+ If not specified and `num_frames==None`, all frames are sampled.
+ backend (`str`, *optional*, defaults to `"pyav"`):
+ The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision", "torchcodec"]. Defaults to "pyav".
+ sample_indices_fn (`Callable`, *optional*):
+ A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
+ by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
+ If not provided, simple uniformt sampling with fps is performed, otherwise `sample_indices_fn` has priority over other args.
+ The function expects at input the all args along with all kwargs passed to `load_video` and should output valid
+ indices at which the video should be sampled. For example:
+
+ Example:
+ def sample_indices_fn(metadata, **kwargs):
+ return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
+
+ Returns:
+ tuple[`np.ndarray`, Dict]: A tuple containing:
+ - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
+ - Metadata dictionary.
+ """
+
+ # If `sample_indices_fn` is given, we can accept any args as those might be needed by custom `sample_indices_fn`
+ if fps is not None and num_frames is not None and sample_indices_fn is None:
+ raise ValueError(
+ "`num_frames`, `fps`, and `sample_indices_fn` are mutually exclusive arguments, please use only one!"
+ )
+
+ # If user didn't pass a sampling function, create one on the fly with default logic
+ if sample_indices_fn is None:
+
+ def sample_indices_fn_func(metadata, **fn_kwargs):
+ return default_sample_indices_fn(metadata, num_frames=num_frames, fps=fps, **fn_kwargs)
+
+ sample_indices_fn = sample_indices_fn_func
+
+ # Early exit if provided an array or `PIL` frames
+ if not isinstance(video, str):
+ metadata = [None] * len(video)
+ return video, metadata
+
+ if urlparse(video).netloc in ["www.youtube.com", "youtube.com"]:
+ if not is_yt_dlp_available():
+ raise ImportError("To load a video from YouTube url you have to install `yt_dlp` first.")
+ # Lazy import from yt_dlp
+ requires_backends(load_video, ["yt_dlp"])
+ from yt_dlp import YoutubeDL
+
+ buffer = BytesIO()
+ with redirect_stdout(buffer), YoutubeDL() as f:
+ f.download([video])
+ bytes_obj = buffer.getvalue()
+ file_obj = BytesIO(bytes_obj)
+ elif video.startswith("http://") or video.startswith("https://"):
+ file_obj = BytesIO(requests.get(video).content)
+ elif os.path.isfile(video):
+ file_obj = video
+ else:
+ raise TypeError("Incorrect format used for video. Should be an url linking to an video or a local path.")
+
+ # can also load with decord, but not cv2/torchvision
+ # both will fail in case of url links
+ video_is_url = video.startswith("http://") or video.startswith("https://")
+ if video_is_url and backend == "opencv":
+ raise ValueError("If you are trying to load a video from URL, you cannot use 'opencv' as backend")
+
+ if (
+ (not is_decord_available() and backend == "decord")
+ or (not is_av_available() and backend == "pyav")
+ or (not is_cv2_available() and backend == "opencv")
+ or (not is_torchvision_available() and backend == "torchvision")
+ or (not is_torchcodec_available() and backend == "torchcodec")
+ ):
+ raise ImportError(
+ f"You chose backend={backend} for loading the video but the required library is not found in your environment "
+ f"Make sure to install {backend} before loading the video."
+ )
+
+ video_decoder = VIDEO_DECODERS[backend]
+ video, metadata = video_decoder(file_obj, sample_indices_fn, **kwargs)
+ return video, metadata
+
+
+def convert_to_rgb(
+ video: np.ndarray,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+) -> np.ndarray:
+ """
+ Convert video to RGB by blending the transparency layer if it's in RGBA format, otherwise simply returns it.
+
+ Args:
+ video (`np.ndarray`):
+ The video to convert.
+ input_data_format (`ChannelDimension`, *optional*):
+ The channel dimension format of the input video. If unset, will use the inferred format from the input.
+ """
+ if not isinstance(video, np.ndarray):
+ raise TypeError(f"Video has to be a numpy array to convert to RGB format, but found {type(video)}")
+
+ # np.array usually comes with ChannelDimension.LAST so let's convert it
+ if input_data_format is None:
+ input_data_format = infer_channel_dimension_format(video)
+ video = to_channel_dimension_format(video, ChannelDimension.FIRST, input_channel_dim=input_data_format)
+
+ # 3 channels for RGB already
+ if video.shape[-3] == 3:
+ return video
+
+ # Grayscale video so we repeat it 3 times for each channel
+ if video.shape[-3] == 1:
+ return video.repeat(3, -3)
+
+ if not (video[..., 3, :, :] < 255).any():
+ return video
+
+ # There is a transparency layer, blend it with a white background.
+ # Calculate the alpha proportion for blending.
+ alpha = video[..., 3, :, :] / 255.0
+ video = (1 - alpha[..., None, :, :]) * 255 + alpha[..., None, :, :] * video[..., 3, :, :]
+ return video
+
+
+def pad(
+ video: np.ndarray,
+ padding: Union[int, tuple[int, int], Iterable[tuple[int, int]]],
+ mode: PaddingMode = PaddingMode.CONSTANT,
+ constant_values: Union[float, Iterable[float]] = 0.0,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+) -> np.ndarray:
+ """
+ Pads the `video` with the specified (height, width) `padding` and `mode`.
+
+ Args:
+ video (`np.ndarray`):
+ The video to pad.
+ padding (`int` or `tuple[int, int]` or `Iterable[tuple[int, int]]`):
+ Padding to apply to the edges of the height, width axes. Can be one of three formats:
+ - `((before_height, after_height), (before_width, after_width))` unique pad widths for each axis.
+ - `((before, after),)` yields same before and after pad for height and width.
+ - `(pad,)` or int is a shortcut for before = after = pad width for all axes.
+ mode (`PaddingMode`):
+ The padding mode to use. Can be one of:
+ - `"constant"`: pads with a constant value.
+ - `"reflect"`: pads with the reflection of the vector mirrored on the first and last values of the
+ vector along each axis.
+ - `"replicate"`: pads with the replication of the last value on the edge of the array along each axis.
+ - `"symmetric"`: pads with the reflection of the vector mirrored along the edge of the array.
+ constant_values (`float` or `Iterable[float]`, *optional*):
+ The value to use for the padding if `mode` is `"constant"`.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format for the output video. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: video in (num_frames, num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: video in (num_frames, height, width, num_channels) format.
+ If unset, will use same as the input video.
+ input_data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format for the input video. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: video in (num_frames, num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: video in (num_frames, height, width, num_channels) format.
+ If unset, will use the inferred format of the input video.
+
+ Returns:
+ `np.ndarray`: The padded video.
+
+ """
+ if input_data_format is None:
+ input_data_format = infer_channel_dimension_format(video)
+
+ def _expand_for_data_format(values):
+ """
+ Convert values to be in the format expected by np.pad based on the data format.
+ """
+ if isinstance(values, (int, float)):
+ values = ((values, values), (values, values))
+ elif isinstance(values, tuple) and len(values) == 1:
+ values = ((values[0], values[0]), (values[0], values[0]))
+ elif isinstance(values, tuple) and len(values) == 2 and isinstance(values[0], int):
+ values = (values, values)
+ elif isinstance(values, tuple) and len(values) == 2 and isinstance(values[0], tuple):
+ pass
+ else:
+ raise ValueError(f"Unsupported format: {values}")
+
+ # add 0 for channel dimension
+ values = (
+ ((0, 0), (0, 0), *values) if input_data_format == ChannelDimension.FIRST else ((0, 0), *values, (0, 0))
+ )
+
+ # Add additional padding if there's a batch dimension
+ values = (0, *values) if video.ndim == 5 else values
+ return values
+
+ padding_map = {
+ PaddingMode.CONSTANT: "constant",
+ PaddingMode.REFLECT: "reflect",
+ PaddingMode.REPLICATE: "replicate",
+ PaddingMode.SYMMETRIC: "symmetric",
+ }
+ padding = _expand_for_data_format(padding)
+
+ pad_kwargs = {}
+ if mode not in padding_map:
+ raise ValueError(f"Invalid padding mode: {mode}")
+ elif mode == PaddingMode.CONSTANT:
+ pad_kwargs["constant_values"] = _expand_for_data_format(constant_values)
+
+ video = np.pad(video, padding, mode=padding_map[mode], **pad_kwargs)
+ video = to_channel_dimension_format(video, data_format, input_data_format) if data_format is not None else video
+ return video
+
+
+def group_videos_by_shape(
+ videos: list["torch.Tensor"],
+) -> tuple[dict[tuple[int, int], "torch.Tensor"], dict[int, tuple[tuple[int, int], int]]]:
+ """
+ Groups videos by shape.
+ Returns a dictionary with the shape as key and a list of videos with that shape as value,
+ and a dictionary with the index of the video in the original list as key and the shape and index in the grouped list as value.
+ """
+ grouped_videos = {}
+ grouped_videos_index = {}
+ for i, video in enumerate(videos):
+ shape = video.shape[-2::]
+ num_frames = video.shape[-4] # video format BTCHW
+ shape = (num_frames, *shape)
+ if shape not in grouped_videos:
+ grouped_videos[shape] = []
+ grouped_videos[shape].append(video)
+ grouped_videos_index[i] = (shape, len(grouped_videos[shape]) - 1)
+ # stack videos with the same size and number of frames
+ grouped_videos = {shape: torch.stack(videos, dim=0) for shape, videos in grouped_videos.items()}
+ return grouped_videos, grouped_videos_index
+
+
+def reorder_videos(
+ processed_videos: dict[tuple[int, int], "torch.Tensor"],
+ grouped_videos_index: dict[int, tuple[tuple[int, int], int]],
+) -> list["torch.Tensor"]:
+ """
+ Reconstructs a list of videos in the original order.
+ """
+ return [
+ processed_videos[grouped_videos_index[i][0]][grouped_videos_index[i][1]]
+ for i in range(len(grouped_videos_index))
+ ]