diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/ElementInclude.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/ElementInclude.py new file mode 100644 index 0000000000000000000000000000000000000000..21884336f534cd2013165934111146684c9909cf --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/ElementInclude.py @@ -0,0 +1,244 @@ +# +# ElementTree +# $Id: ElementInclude.py 1862 2004-06-18 07:31:02Z Fredrik $ +# +# limited xinclude support for element trees +# +# history: +# 2003-08-15 fl created +# 2003-11-14 fl fixed default loader +# +# Copyright (c) 2003-2004 by Fredrik Lundh. All rights reserved. +# +# fredrik@pythonware.com +# http://www.pythonware.com +# +# -------------------------------------------------------------------- +# The ElementTree toolkit is +# +# Copyright (c) 1999-2004 by Fredrik Lundh +# +# By obtaining, using, and/or copying this software and/or its +# associated documentation, you agree that you have read, understood, +# and will comply with the following terms and conditions: +# +# Permission to use, copy, modify, and distribute this software and +# its associated documentation for any purpose and without fee is +# hereby granted, provided that the above copyright notice appears in +# all copies, and that both that copyright notice and this permission +# notice appear in supporting documentation, and that the name of +# Secret Labs AB or the author not be used in advertising or publicity +# pertaining to distribution of the software without specific, written +# prior permission. +# +# SECRET LABS AB AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD +# TO THIS SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANT- +# ABILITY AND FITNESS. IN NO EVENT SHALL SECRET LABS AB OR THE AUTHOR +# BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY +# DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, +# WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS +# ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE +# OF THIS SOFTWARE. +# -------------------------------------------------------------------- + +""" +Limited XInclude support for the ElementTree package. + +While lxml.etree has full support for XInclude (see +`etree.ElementTree.xinclude()`), this module provides a simpler, pure +Python, ElementTree compatible implementation that supports a simple +form of custom URL resolvers. +""" + +from lxml import etree +try: + from urlparse import urljoin + from urllib2 import urlopen +except ImportError: + # Python 3 + from urllib.parse import urljoin + from urllib.request import urlopen + +XINCLUDE = "{http://www.w3.org/2001/XInclude}" + +XINCLUDE_INCLUDE = XINCLUDE + "include" +XINCLUDE_FALLBACK = XINCLUDE + "fallback" +XINCLUDE_ITER_TAG = XINCLUDE + "*" + +# For security reasons, the inclusion depth is limited to this read-only value by default. +DEFAULT_MAX_INCLUSION_DEPTH = 6 + + +## +# Fatal include error. + +class FatalIncludeError(etree.LxmlSyntaxError): + pass + + +class LimitedRecursiveIncludeError(FatalIncludeError): + pass + + +## +# ET compatible default loader. +# This loader reads an included resource from disk. +# +# @param href Resource reference. +# @param parse Parse mode. Either "xml" or "text". +# @param encoding Optional text encoding. +# @return The expanded resource. If the parse mode is "xml", this +# is an ElementTree instance. If the parse mode is "text", this +# is a Unicode string. If the loader fails, it can return None +# or raise an IOError exception. +# @throws IOError If the loader fails to load the resource. + +def default_loader(href, parse, encoding=None): + file = open(href, 'rb') + if parse == "xml": + data = etree.parse(file).getroot() + else: + data = file.read() + if not encoding: + encoding = 'utf-8' + data = data.decode(encoding) + file.close() + return data + + +## +# Default loader used by lxml.etree - handles custom resolvers properly +# + +def _lxml_default_loader(href, parse, encoding=None, parser=None): + if parse == "xml": + data = etree.parse(href, parser).getroot() + else: + if "://" in href: + f = urlopen(href) + else: + f = open(href, 'rb') + data = f.read() + f.close() + if not encoding: + encoding = 'utf-8' + data = data.decode(encoding) + return data + + +## +# Wrapper for ET compatibility - drops the parser + +def _wrap_et_loader(loader): + def load(href, parse, encoding=None, parser=None): + return loader(href, parse, encoding) + return load + + +## +# Expand XInclude directives. +# +# @param elem Root element. +# @param loader Optional resource loader. If omitted, it defaults +# to {@link default_loader}. If given, it should be a callable +# that implements the same interface as default_loader. +# @param base_url The base URL of the original file, to resolve +# relative include file references. +# @param max_depth The maximum number of recursive inclusions. +# Limited to reduce the risk of malicious content explosion. +# Pass None to disable the limitation. +# @throws LimitedRecursiveIncludeError If the {@link max_depth} was exceeded. +# @throws FatalIncludeError If the function fails to include a given +# resource, or if the tree contains malformed XInclude elements. +# @throws IOError If the function fails to load a given resource. +# @returns the node or its replacement if it was an XInclude node + +def include(elem, loader=None, base_url=None, + max_depth=DEFAULT_MAX_INCLUSION_DEPTH): + if max_depth is None: + max_depth = -1 + elif max_depth < 0: + raise ValueError("expected non-negative depth or None for 'max_depth', got %r" % max_depth) + + if base_url is None: + if hasattr(elem, 'getroot'): + tree = elem + elem = elem.getroot() + else: + tree = elem.getroottree() + if hasattr(tree, 'docinfo'): + base_url = tree.docinfo.URL + elif hasattr(elem, 'getroot'): + elem = elem.getroot() + _include(elem, loader, base_url, max_depth) + + +def _include(elem, loader=None, base_url=None, + max_depth=DEFAULT_MAX_INCLUSION_DEPTH, _parent_hrefs=None): + if loader is not None: + load_include = _wrap_et_loader(loader) + else: + load_include = _lxml_default_loader + + if _parent_hrefs is None: + _parent_hrefs = set() + + parser = elem.getroottree().parser + + include_elements = list( + elem.iter(XINCLUDE_ITER_TAG)) + + for e in include_elements: + if e.tag == XINCLUDE_INCLUDE: + # process xinclude directive + href = urljoin(base_url, e.get("href")) + parse = e.get("parse", "xml") + parent = e.getparent() + if parse == "xml": + if href in _parent_hrefs: + raise FatalIncludeError( + "recursive include of %r detected" % href + ) + if max_depth == 0: + raise LimitedRecursiveIncludeError( + "maximum xinclude depth reached when including file %s" % href) + node = load_include(href, parse, parser=parser) + if node is None: + raise FatalIncludeError( + "cannot load %r as %r" % (href, parse) + ) + node = _include(node, loader, href, max_depth - 1, {href} | _parent_hrefs) + if e.tail: + node.tail = (node.tail or "") + e.tail + if parent is None: + return node # replaced the root node! + parent.replace(e, node) + elif parse == "text": + text = load_include(href, parse, encoding=e.get("encoding")) + if text is None: + raise FatalIncludeError( + "cannot load %r as %r" % (href, parse) + ) + predecessor = e.getprevious() + if predecessor is not None: + predecessor.tail = (predecessor.tail or "") + text + elif parent is None: + return text # replaced the root node! + else: + parent.text = (parent.text or "") + text + (e.tail or "") + parent.remove(e) + else: + raise FatalIncludeError( + "unknown parse type in xi:include tag (%r)" % parse + ) + elif e.tag == XINCLUDE_FALLBACK: + parent = e.getparent() + if parent is not None and parent.tag != XINCLUDE_INCLUDE: + raise FatalIncludeError( + "xi:fallback tag must be child of xi:include (%r)" % e.tag + ) + else: + raise FatalIncludeError( + "Invalid element found in XInclude namespace (%r)" % e.tag + ) + return elem diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..58c2133db7929faeafa4fd61c314a378ed4f3977 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/__init__.py @@ -0,0 +1,22 @@ +# this is a package + +__version__ = "6.0.2" + + +def get_include(): + """ + Returns a list of header include paths (for lxml itself, libxml2 + and libxslt) needed to compile C code against lxml if it was built + with statically linked libraries. + """ + import os + lxml_path = __path__[0] + include_path = os.path.join(lxml_path, 'includes') + includes = [include_path, lxml_path] + + for name in os.listdir(include_path): + path = os.path.join(include_path, name) + if os.path.isdir(path): + includes.append(path) + + return includes diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/_elementpath.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/_elementpath.py new file mode 100644 index 0000000000000000000000000000000000000000..760a1e00b8e1611e3915085482af1d9efc3a6ae1 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/_elementpath.py @@ -0,0 +1,343 @@ +# cython: language_level=3 + +# +# ElementTree +# $Id: ElementPath.py 3375 2008-02-13 08:05:08Z fredrik $ +# +# limited xpath support for element trees +# +# history: +# 2003-05-23 fl created +# 2003-05-28 fl added support for // etc +# 2003-08-27 fl fixed parsing of periods in element names +# 2007-09-10 fl new selection engine +# 2007-09-12 fl fixed parent selector +# 2007-09-13 fl added iterfind; changed findall to return a list +# 2007-11-30 fl added namespaces support +# 2009-10-30 fl added child element value filter +# +# Copyright (c) 2003-2009 by Fredrik Lundh. All rights reserved. +# +# fredrik@pythonware.com +# http://www.pythonware.com +# +# -------------------------------------------------------------------- +# The ElementTree toolkit is +# +# Copyright (c) 1999-2009 by Fredrik Lundh +# +# By obtaining, using, and/or copying this software and/or its +# associated documentation, you agree that you have read, understood, +# and will comply with the following terms and conditions: +# +# Permission to use, copy, modify, and distribute this software and +# its associated documentation for any purpose and without fee is +# hereby granted, provided that the above copyright notice appears in +# all copies, and that both that copyright notice and this permission +# notice appear in supporting documentation, and that the name of +# Secret Labs AB or the author not be used in advertising or publicity +# pertaining to distribution of the software without specific, written +# prior permission. +# +# SECRET LABS AB AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD +# TO THIS SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANT- +# ABILITY AND FITNESS. IN NO EVENT SHALL SECRET LABS AB OR THE AUTHOR +# BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY +# DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, +# WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS +# ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE +# OF THIS SOFTWARE. +# -------------------------------------------------------------------- + +## +# Implementation module for XPath support. There's usually no reason +# to import this module directly; the ElementTree does this for +# you, if needed. +## + + +import re + +xpath_tokenizer_re = re.compile( + "(" + "'[^']*'|\"[^\"]*\"|" + "::|" + "//?|" + r"\.\.|" + r"\(\)|" + r"[/.*:\[\]\(\)@=])|" + r"((?:\{[^}]+\})?[^/\[\]\(\)@=\s]+)|" + r"\s+" + ) + +def xpath_tokenizer(pattern, namespaces=None, with_prefixes=True): + # ElementTree uses '', lxml used None originally. + default_namespace = (namespaces.get(None) or namespaces.get('')) if namespaces else None + parsing_attribute = False + for token in xpath_tokenizer_re.findall(pattern): + ttype, tag = token + if tag and tag[0] != "{": + if ":" in tag and with_prefixes: + prefix, uri = tag.split(":", 1) + try: + if not namespaces: + raise KeyError + yield ttype, "{%s}%s" % (namespaces[prefix], uri) + except KeyError: + raise SyntaxError("prefix %r not found in prefix map" % prefix) + elif tag.isdecimal(): + yield token # index + elif default_namespace and not parsing_attribute: + yield ttype, "{%s}%s" % (default_namespace, tag) + else: + yield token + parsing_attribute = False + else: + yield token + parsing_attribute = ttype == '@' + + +def prepare_child(next, token): + tag = token[1] + def select(result): + for elem in result: + yield from elem.iterchildren(tag) + return select + +def prepare_star(next, token): + def select(result): + for elem in result: + yield from elem.iterchildren('*') + return select + +def prepare_self(next, token): + def select(result): + return result + return select + +def prepare_descendant(next, token): + token = next() + if token[0] == "*": + tag = "*" + elif not token[0]: + tag = token[1] + else: + raise SyntaxError("invalid descendant") + def select(result): + for elem in result: + yield from elem.iterdescendants(tag) + return select + +def prepare_parent(next, token): + def select(result): + for elem in result: + parent = elem.getparent() + if parent is not None: + yield parent + return select + +def prepare_predicate(next, token): + # FIXME: replace with real parser!!! refs: + # http://effbot.org/zone/simple-iterator-parser.htm + # http://javascript.crockford.com/tdop/tdop.html + signature = '' + predicate = [] + while 1: + token = next() + if token[0] == "]": + break + if token == ('', ''): + # ignore whitespace + continue + if token[0] and token[0][:1] in "'\"": + token = "'", token[0][1:-1] + signature += token[0] or "-" + predicate.append(token[1]) + + # use signature to determine predicate type + if signature == "@-": + # [@attribute] predicate + key = predicate[1] + def select(result): + for elem in result: + if elem.get(key) is not None: + yield elem + return select + if signature == "@-='": + # [@attribute='value'] + key = predicate[1] + value = predicate[-1] + def select(result): + for elem in result: + if elem.get(key) == value: + yield elem + return select + if signature == "-" and not re.match(r"-?\d+$", predicate[0]): + # [tag] + tag = predicate[0] + def select(result): + for elem in result: + for _ in elem.iterchildren(tag): + yield elem + break + return select + if signature == ".='" or (signature == "-='" and not re.match(r"-?\d+$", predicate[0])): + # [.='value'] or [tag='value'] + tag = predicate[0] + value = predicate[-1] + if tag: + def select(result): + for elem in result: + for e in elem.iterchildren(tag): + if "".join(e.itertext()) == value: + yield elem + break + else: + def select(result): + for elem in result: + if "".join(elem.itertext()) == value: + yield elem + return select + if signature == "-" or signature == "-()" or signature == "-()-": + # [index] or [last()] or [last()-index] + if signature == "-": + # [index] + index = int(predicate[0]) - 1 + if index < 0: + if index == -1: + raise SyntaxError( + "indices in path predicates are 1-based, not 0-based") + else: + raise SyntaxError("path index >= 1 expected") + else: + if predicate[0] != "last": + raise SyntaxError("unsupported function") + if signature == "-()-": + try: + index = int(predicate[2]) - 1 + except ValueError: + raise SyntaxError("unsupported expression") + else: + index = -1 + def select(result): + for elem in result: + parent = elem.getparent() + if parent is None: + continue + try: + # FIXME: what if the selector is "*" ? + elems = list(parent.iterchildren(elem.tag)) + if elems[index] is elem: + yield elem + except IndexError: + pass + return select + raise SyntaxError("invalid predicate") + +ops = { + "": prepare_child, + "*": prepare_star, + ".": prepare_self, + "..": prepare_parent, + "//": prepare_descendant, + "[": prepare_predicate, +} + + +# -------------------------------------------------------------------- + +_cache = {} + + +def _build_path_iterator(path, namespaces, with_prefixes=True): + """compile selector pattern""" + if path[-1:] == "/": + path += "*" # implicit all (FIXME: keep this?) + + cache_key = (path,) + if namespaces: + # lxml originally used None for the default namespace but ElementTree uses the + # more convenient (all-strings-dict) empty string, so we support both here, + # preferring the more convenient '', as long as they aren't ambiguous. + if None in namespaces: + if '' in namespaces and namespaces[None] != namespaces['']: + raise ValueError("Ambiguous default namespace provided: %r versus %r" % ( + namespaces[None], namespaces[''])) + cache_key += (namespaces[None],) + tuple(sorted( + item for item in namespaces.items() if item[0] is not None)) + else: + cache_key += tuple(sorted(namespaces.items())) + + try: + return _cache[cache_key] + except KeyError: + pass + if len(_cache) > 100: + _cache.clear() + + if path[:1] == "/": + raise SyntaxError("cannot use absolute path on element") + stream = iter(xpath_tokenizer(path, namespaces, with_prefixes=with_prefixes)) + try: + _next = stream.next + except AttributeError: + # Python 3 + _next = stream.__next__ + try: + token = _next() + except StopIteration: + raise SyntaxError("empty path expression") + selector = [] + while 1: + try: + selector.append(ops[token[0]](_next, token)) + except StopIteration: + raise SyntaxError("invalid path") + try: + token = _next() + if token[0] == "/": + token = _next() + except StopIteration: + break + _cache[cache_key] = selector + return selector + + +## +# Iterate over the matching nodes + +def iterfind(elem, path, namespaces=None, with_prefixes=True): + selector = _build_path_iterator(path, namespaces, with_prefixes=with_prefixes) + result = iter((elem,)) + for select in selector: + result = select(result) + return result + + +## +# Find first matching object. + +def find(elem, path, namespaces=None, with_prefixes=True): + it = iterfind(elem, path, namespaces, with_prefixes=with_prefixes) + try: + return next(it) + except StopIteration: + return None + + +## +# Find all matching objects. + +def findall(elem, path, namespaces=None, with_prefixes=True): + return list(iterfind(elem, path, namespaces)) + + +## +# Find text for first matching object. + +def findtext(elem, path, default=None, namespaces=None, with_prefixes=True): + el = find(elem, path, namespaces, with_prefixes=with_prefixes) + if el is None: + return default + else: + return el.text or '' diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/apihelpers.pxi b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/apihelpers.pxi new file mode 100644 index 0000000000000000000000000000000000000000..f683e70db95645fe22e1af17a52efc1993a352f0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/apihelpers.pxi @@ -0,0 +1,1801 @@ +# Private/public helper functions for API functions + +from lxml.includes cimport uri + + +cdef void displayNode(xmlNode* c_node, indent) noexcept: + # to help with debugging + cdef xmlNode* c_child + try: + print(indent * ' ', c_node) + c_child = c_node.children + while c_child is not NULL: + displayNode(c_child, indent + 1) + c_child = c_child.next + finally: + return # swallow any exceptions + +cdef inline bint _isHtmlDocument(_Element element) except -1: + cdef xmlNode* c_node = element._c_node + return ( + c_node is not NULL and c_node.doc is not NULL and + c_node.doc.properties & tree.XML_DOC_HTML != 0 + ) + +cdef inline int _assertValidNode(_Element element) except -1: + assert element._c_node is not NULL, "invalid Element proxy at %s" % id(element) + +cdef inline int _assertValidDoc(_Document doc) except -1: + assert doc._c_doc is not NULL, "invalid Document proxy at %s" % id(doc) + +cdef _Document _documentOrRaise(object input): + """Call this to get the document of a _Document, _ElementTree or _Element + object, or to raise an exception if it can't be determined. + + Should be used in all API functions for consistency. + """ + cdef _Document doc + if isinstance(input, _ElementTree): + if (<_ElementTree>input)._context_node is not None: + doc = (<_ElementTree>input)._context_node._doc + else: + doc = None + elif isinstance(input, _Element): + doc = (<_Element>input)._doc + elif isinstance(input, _Document): + doc = <_Document>input + else: + raise TypeError, f"Invalid input object: {python._fqtypename(input).decode('utf8')}" + if doc is None: + raise ValueError, f"Input object has no document: {python._fqtypename(input).decode('utf8')}" + _assertValidDoc(doc) + return doc + +cdef _Element _rootNodeOrRaise(object input): + """Call this to get the root node of a _Document, _ElementTree or + _Element object, or to raise an exception if it can't be determined. + + Should be used in all API functions for consistency. + """ + cdef _Element node + if isinstance(input, _ElementTree): + node = (<_ElementTree>input)._context_node + elif isinstance(input, _Element): + node = <_Element>input + elif isinstance(input, _Document): + node = (<_Document>input).getroot() + else: + raise TypeError, f"Invalid input object: {python._fqtypename(input).decode('utf8')}" + if (node is None or not node._c_node or + node._c_node.type != tree.XML_ELEMENT_NODE): + raise ValueError, f"Input object is not an XML element: {python._fqtypename(input).decode('utf8')}" + _assertValidNode(node) + return node + +cdef bint _isAncestorOrSame(xmlNode* c_ancestor, xmlNode* c_node) noexcept: + while c_node: + if c_node is c_ancestor: + return True + c_node = c_node.parent + return False + +cdef _Element _makeElement(tag, xmlDoc* c_doc, _Document doc, + _BaseParser parser, text, tail, attrib, nsmap, + dict extra_attrs): + """Create a new element and initialize text content, namespaces and + attributes. + + This helper function will reuse as much of the existing document as + possible: + + If 'parser' is None, the parser will be inherited from 'doc' or the + default parser will be used. + + If 'doc' is None, 'c_doc' is used to create a new _Document and the new + element is made its root node. + + If 'c_doc' is also NULL, a new xmlDoc will be created. + """ + cdef xmlNode* c_node + if doc is not None: + c_doc = doc._c_doc + ns_utf, name_utf = _getNsTag(tag) + if parser is not None and parser._for_html: + _htmlTagValidOrRaise(name_utf) + if c_doc is NULL: + c_doc = _newHTMLDoc() + else: + _tagValidOrRaise(name_utf) + if c_doc is NULL: + c_doc = _newXMLDoc() + c_node = _createElement(c_doc, name_utf) + if c_node is NULL: + if doc is None and c_doc is not NULL: + tree.xmlFreeDoc(c_doc) + raise MemoryError() + try: + if doc is None: + tree.xmlDocSetRootElement(c_doc, c_node) + doc = _documentFactory(c_doc, parser) + if text is not None: + _setNodeText(c_node, text) + if tail is not None: + _setTailText(c_node, tail) + # add namespaces to node if necessary + _setNodeNamespaces(c_node, doc, ns_utf, nsmap) + _initNodeAttributes(c_node, doc, attrib, extra_attrs) + return _elementFactory(doc, c_node) + except: + # free allocated c_node/c_doc unless Python does it for us + if c_node.doc is not c_doc: + # node not yet in document => will not be freed by document + if tail is not None: + _removeText(c_node.next) # tail + tree.xmlFreeNode(c_node) + if doc is None: + # c_doc will not be freed by doc + tree.xmlFreeDoc(c_doc) + raise + +cdef int _initNewElement(_Element element, bint is_html, name_utf, ns_utf, + _BaseParser parser, attrib, nsmap, dict extra_attrs) except -1: + """Initialise a new Element object. + + This is used when users instantiate a Python Element subclass + directly, without it being mapped to an existing XML node. + """ + cdef xmlDoc* c_doc + cdef xmlNode* c_node + cdef _Document doc + if is_html: + _htmlTagValidOrRaise(name_utf) + c_doc = _newHTMLDoc() + else: + _tagValidOrRaise(name_utf) + c_doc = _newXMLDoc() + c_node = _createElement(c_doc, name_utf) + if c_node is NULL: + if c_doc is not NULL: + tree.xmlFreeDoc(c_doc) + raise MemoryError() + tree.xmlDocSetRootElement(c_doc, c_node) + doc = _documentFactory(c_doc, parser) + # add namespaces to node if necessary + _setNodeNamespaces(c_node, doc, ns_utf, nsmap) + _initNodeAttributes(c_node, doc, attrib, extra_attrs) + _registerProxy(element, doc, c_node) + element._init() + return 0 + +cdef _Element _makeSubElement(_Element parent, tag, text, tail, + attrib, nsmap, dict extra_attrs): + """Create a new child element and initialize text content, namespaces and + attributes. + """ + cdef xmlNode* c_node + cdef xmlDoc* c_doc + if parent is None or parent._doc is None: + return None + _assertValidNode(parent) + ns_utf, name_utf = _getNsTag(tag) + c_doc = parent._doc._c_doc + + if parent._doc._parser is not None and parent._doc._parser._for_html: + _htmlTagValidOrRaise(name_utf) + else: + _tagValidOrRaise(name_utf) + + c_node = _createElement(c_doc, name_utf) + if c_node is NULL: + raise MemoryError() + tree.xmlAddChild(parent._c_node, c_node) + + try: + if text is not None: + _setNodeText(c_node, text) + if tail is not None: + _setTailText(c_node, tail) + + # add namespaces to node if necessary + _setNodeNamespaces(c_node, parent._doc, ns_utf, nsmap) + _initNodeAttributes(c_node, parent._doc, attrib, extra_attrs) + return _elementFactory(parent._doc, c_node) + except: + # make sure we clean up in case of an error + _removeNode(parent._doc, c_node) + raise + + +cdef int _setNodeNamespaces(xmlNode* c_node, _Document doc, + object node_ns_utf, object nsmap) except -1: + """Lookup current namespace prefixes, then set namespace structure for + node (if 'node_ns_utf' was provided) and register new ns-prefix mappings. + + 'node_ns_utf' should only be passed for a newly created node. + """ + cdef xmlNs* c_ns + cdef list nsdefs + + if nsmap: + for prefix, href in _iter_nsmap(nsmap): + href_utf = _utf8(href) + _uriValidOrRaise(href_utf) + c_href = _xcstr(href_utf) + if prefix is not None: + prefix_utf = _utf8(prefix) + _prefixValidOrRaise(prefix_utf) + c_prefix = _xcstr(prefix_utf) + else: + c_prefix = NULL + # add namespace with prefix if it is not already known + c_ns = tree.xmlSearchNs(doc._c_doc, c_node, c_prefix) + if c_ns is NULL or \ + c_ns.href is NULL or \ + tree.xmlStrcmp(c_ns.href, c_href) != 0: + c_ns = tree.xmlNewNs(c_node, c_href, c_prefix) + if href_utf == node_ns_utf: + tree.xmlSetNs(c_node, c_ns) + node_ns_utf = None + + if node_ns_utf is not None: + _uriValidOrRaise(node_ns_utf) + doc._setNodeNs(c_node, _xcstr(node_ns_utf)) + return 0 + + +cdef dict _build_nsmap(xmlNode* c_node): + """ + Namespace prefix->URI mapping known in the context of this Element. + This includes all namespace declarations of the parents. + """ + cdef xmlNs* c_ns + nsmap = {} + while c_node is not NULL and c_node.type == tree.XML_ELEMENT_NODE: + c_ns = c_node.nsDef + while c_ns is not NULL: + if c_ns.prefix or c_ns.href: + prefix = funicodeOrNone(c_ns.prefix) + if prefix not in nsmap: + nsmap[prefix] = funicodeOrNone(c_ns.href) + c_ns = c_ns.next + c_node = c_node.parent + return nsmap + + +cdef _iter_nsmap(nsmap): + """ + Create a reproducibly ordered iterable from an nsmap mapping. + Tries to preserve an existing order and sorts if it assumes no order. + + The difference to _iter_attrib() is that None doesn't sort with strings + in Py3.x. + """ + if isinstance(nsmap, dict): + # dicts are insertion-ordered in Py3.6+ => keep the user provided order. + return nsmap.items() + if len(nsmap) <= 1: + return nsmap.items() + # nsmap will usually be a plain unordered dict => avoid type checking overhead + if type(nsmap) is not dict and isinstance(nsmap, OrderedDict): + return nsmap.items() # keep existing order + if None not in nsmap: + return sorted(nsmap.items()) + + # Move the default namespace to the end. This makes sure libxml2 + # prefers a prefix if the ns is defined redundantly on the same + # element. That way, users can work around a problem themselves + # where default namespace attributes on non-default namespaced + # elements serialise without prefix (i.e. into the non-default + # namespace). + default_ns = nsmap[None] + nsdefs = [(k, v) for k, v in nsmap.items() if k is not None] + nsdefs.sort() + nsdefs.append((None, default_ns)) + return nsdefs + + +cdef _iter_attrib(attrib): + """ + Create a reproducibly ordered iterable from an attrib mapping. + Tries to preserve an existing order and sorts if it assumes no order. + """ + # dicts are insertion-ordered in Py3.6+ => keep the user provided order. + if isinstance(attrib, (dict, _Attrib, OrderedDict)): + return attrib.items() + # assume it's an unordered mapping of some kind + return sorted(attrib.items()) + + +cdef _initNodeAttributes(xmlNode* c_node, _Document doc, attrib, dict extra): + """Initialise the attributes of an element node. + """ + cdef bint is_html + cdef xmlNs* c_ns + if attrib is not None and not hasattr(attrib, 'items'): + raise TypeError, f"Invalid attribute dictionary: {python._fqtypename(attrib).decode('utf8')}" + if not attrib and not extra: + return # nothing to do + is_html = doc._parser._for_html + seen = set() + if extra: + for name, value in extra.items(): + _addAttributeToNode(c_node, doc, is_html, name, value, seen) + if attrib: + for name, value in _iter_attrib(attrib): + _addAttributeToNode(c_node, doc, is_html, name, value, seen) + + +cdef int _addAttributeToNode(xmlNode* c_node, _Document doc, bint is_html, + name, value, set seen_tags) except -1: + ns_utf, name_utf = tag = _getNsTag(name) + if tag in seen_tags: + return 0 + seen_tags.add(tag) + if not is_html: + _attributeValidOrRaise(name_utf) + value_utf = _utf8(value) + if ns_utf is None: + tree.xmlNewProp(c_node, _xcstr(name_utf), _xcstr(value_utf)) + else: + _uriValidOrRaise(ns_utf) + c_ns = doc._findOrBuildNodeNs(c_node, _xcstr(ns_utf), NULL, 1) + tree.xmlNewNsProp(c_node, c_ns, + _xcstr(name_utf), _xcstr(value_utf)) + return 0 + + +ctypedef struct _ns_node_ref: + xmlNs* ns + xmlNode* node + + +cdef int _collectNsDefs(xmlNode* c_element, _ns_node_ref **_c_ns_list, + size_t *_c_ns_list_len, size_t *_c_ns_list_size) except -1: + c_ns_list = _c_ns_list[0] + cdef size_t c_ns_list_len = _c_ns_list_len[0] + cdef size_t c_ns_list_size = _c_ns_list_size[0] + + c_nsdef = c_element.nsDef + while c_nsdef is not NULL: + if c_ns_list_len >= c_ns_list_size: + if c_ns_list is NULL: + c_ns_list_size = 20 + else: + c_ns_list_size *= 2 + c_nsref_ptr = <_ns_node_ref*> python.lxml_realloc( + c_ns_list, c_ns_list_size, sizeof(_ns_node_ref)) + if c_nsref_ptr is NULL: + if c_ns_list is not NULL: + python.lxml_free(c_ns_list) + _c_ns_list[0] = NULL + raise MemoryError() + c_ns_list = c_nsref_ptr + + c_ns_list[c_ns_list_len] = _ns_node_ref(c_nsdef, c_element) + c_ns_list_len += 1 + c_nsdef = c_nsdef.next + + _c_ns_list_size[0] = c_ns_list_size + _c_ns_list_len[0] = c_ns_list_len + _c_ns_list[0] = c_ns_list + + +cdef int _removeUnusedNamespaceDeclarations(xmlNode* c_element, set prefixes_to_keep) except -1: + """Remove any namespace declarations from a subtree that are not used by + any of its elements (or attributes). + + If a 'prefixes_to_keep' is provided, it must be a set of prefixes. + Any corresponding namespace mappings will not be removed as part of the cleanup. + """ + cdef xmlNode* c_node + cdef _ns_node_ref* c_ns_list = NULL + cdef size_t c_ns_list_size = 0 + cdef size_t c_ns_list_len = 0 + cdef size_t i + + if c_element.parent and c_element.parent.type == tree.XML_DOCUMENT_NODE: + # include declarations on the document node + _collectNsDefs(c_element.parent, &c_ns_list, &c_ns_list_len, &c_ns_list_size) + + tree.BEGIN_FOR_EACH_ELEMENT_FROM(c_element, c_element, 1) + # collect all new namespace declarations into the ns list + if c_element.nsDef: + _collectNsDefs(c_element, &c_ns_list, &c_ns_list_len, &c_ns_list_size) + + # remove all namespace declarations from the list that are referenced + if c_ns_list_len and c_element.type == tree.XML_ELEMENT_NODE: + c_node = c_element + while c_node and c_ns_list_len: + if c_node.ns: + for i in range(c_ns_list_len): + if c_node.ns is c_ns_list[i].ns: + c_ns_list_len -= 1 + c_ns_list[i] = c_ns_list[c_ns_list_len] + #c_ns_list[c_ns_list_len] = _ns_node_ref(NULL, NULL) + break + if c_node is c_element: + # continue with attributes + c_node = c_element.properties + else: + c_node = c_node.next + tree.END_FOR_EACH_ELEMENT_FROM(c_element) + + if c_ns_list is NULL: + return 0 + + # free all namespace declarations that remained in the list, + # except for those we should keep explicitly + cdef xmlNs* c_nsdef + for i in range(c_ns_list_len): + if prefixes_to_keep is not None: + if c_ns_list[i].ns.prefix and c_ns_list[i].ns.prefix in prefixes_to_keep: + continue + c_node = c_ns_list[i].node + c_nsdef = c_node.nsDef + if c_nsdef is c_ns_list[i].ns: + c_node.nsDef = c_node.nsDef.next + else: + while c_nsdef.next is not c_ns_list[i].ns: + c_nsdef = c_nsdef.next + c_nsdef.next = c_nsdef.next.next + tree.xmlFreeNs(c_ns_list[i].ns) + + if c_ns_list is not NULL: + python.lxml_free(c_ns_list) + return 0 + +cdef xmlNs* _searchNsByHref(xmlNode* c_node, const_xmlChar* c_href, bint is_attribute) noexcept: + """Search a namespace declaration that covers a node (element or + attribute). + + For attributes, try to find a prefixed namespace declaration + instead of the default namespaces. This helps in supporting + round-trips for attributes on elements with a different namespace. + """ + cdef xmlNs* c_ns + cdef xmlNs* c_default_ns = NULL + cdef xmlNode* c_element + if c_href is NULL or c_node is NULL or c_node.type == tree.XML_ENTITY_REF_NODE: + return NULL + if tree.xmlStrcmp(c_href, tree.XML_XML_NAMESPACE) == 0: + # no special cases here, let libxml2 handle this + return tree.xmlSearchNsByHref(c_node.doc, c_node, c_href) + if c_node.type == tree.XML_ATTRIBUTE_NODE: + is_attribute = 1 + while c_node is not NULL and c_node.type != tree.XML_ELEMENT_NODE: + c_node = c_node.parent + c_element = c_node + while c_node is not NULL: + if c_node.type == tree.XML_ELEMENT_NODE: + c_ns = c_node.nsDef + while c_ns is not NULL: + if c_ns.href is not NULL and tree.xmlStrcmp(c_href, c_ns.href) == 0: + if c_ns.prefix is NULL and is_attribute: + # for attributes, continue searching a named + # prefix, but keep the first default namespace + # declaration that we found + if c_default_ns is NULL: + c_default_ns = c_ns + elif tree.xmlSearchNs( + c_element.doc, c_element, c_ns.prefix) is c_ns: + # start node is in namespace scope => found! + return c_ns + c_ns = c_ns.next + if c_node is not c_element and c_node.ns is not NULL: + # optimise: the node may have the namespace itself + c_ns = c_node.ns + if c_ns.href is not NULL and tree.xmlStrcmp(c_href, c_ns.href) == 0: + if c_ns.prefix is NULL and is_attribute: + # for attributes, continue searching a named + # prefix, but keep the first default namespace + # declaration that we found + if c_default_ns is NULL: + c_default_ns = c_ns + elif tree.xmlSearchNs( + c_element.doc, c_element, c_ns.prefix) is c_ns: + # start node is in namespace scope => found! + return c_ns + c_node = c_node.parent + # nothing found => use a matching default namespace or fail + if c_default_ns is not NULL: + if tree.xmlSearchNs(c_element.doc, c_element, NULL) is c_default_ns: + return c_default_ns + return NULL + +cdef int _replaceNodeByChildren(_Document doc, xmlNode* c_node) except -1: + # NOTE: this does not deallocate the node, just unlink it! + cdef xmlNode* c_parent + cdef xmlNode* c_child + if c_node.children is NULL: + tree.xmlUnlinkNode(c_node) + return 0 + + c_parent = c_node.parent + # fix parent links of children + c_child = c_node.children + while c_child is not NULL: + c_child.parent = c_parent + c_child = c_child.next + + # fix namespace references of children if their parent's namespace + # declarations get lost + if c_node.nsDef is not NULL: + c_child = c_node.children + while c_child is not NULL: + moveNodeToDocument(doc, doc._c_doc, c_child) + c_child = c_child.next + + # fix sibling links to/from child slice + if c_node.prev is NULL: + c_parent.children = c_node.children + else: + c_node.prev.next = c_node.children + c_node.children.prev = c_node.prev + if c_node.next is NULL: + c_parent.last = c_node.last + else: + c_node.next.prev = c_node.last + c_node.last.next = c_node.next + + # unlink c_node + c_node.children = c_node.last = NULL + c_node.parent = c_node.next = c_node.prev = NULL + return 0 + +cdef unicode _attributeValue(xmlNode* c_element, xmlAttr* c_attrib_node): + c_href = _getNs(c_attrib_node) + value = tree.xmlGetNsProp(c_element, c_attrib_node.name, c_href) + try: + result = funicode(value) + finally: + tree.xmlFree(value) + return result + +cdef unicode _attributeValueFromNsName(xmlNode* c_element, + const_xmlChar* c_href, const_xmlChar* c_name): + c_result = tree.xmlGetNsProp(c_element, c_name, c_href) + if c_result is NULL: + return None + try: + result = funicode(c_result) + finally: + tree.xmlFree(c_result) + return result + +cdef object _getNodeAttributeValue(xmlNode* c_node, key, default): + ns, tag = _getNsTag(key) + c_href = NULL if ns is None else _xcstr(ns) + c_result = tree.xmlGetNsProp(c_node, _xcstr(tag), c_href) + if c_result is NULL: + # XXX free namespace that is not in use..? + return default + try: + result = funicode(c_result) + finally: + tree.xmlFree(c_result) + return result + +cdef inline object _getAttributeValue(_Element element, key, default): + return _getNodeAttributeValue(element._c_node, key, default) + +cdef int _setAttributeValue(_Element element, key, value) except -1: + cdef const_xmlChar* c_value + cdef xmlNs* c_ns + ns, tag = _getNsTag(key) + is_html = element._doc._parser._for_html + if not is_html: + _attributeValidOrRaise(tag) + c_tag = _xcstr(tag) + if value is None and is_html: + c_value = NULL + else: + if isinstance(value, QName): + value = _resolveQNameText(element, value) + else: + value = _utf8(value) + c_value = _xcstr(value) + if ns is None: + c_ns = NULL + else: + c_ns = element._doc._findOrBuildNodeNs(element._c_node, _xcstr(ns), NULL, 1) + tree.xmlSetNsProp(element._c_node, c_ns, c_tag, c_value) + return 0 + +cdef int _delAttribute(_Element element, key) except -1: + ns, tag = _getNsTag(key) + c_href = NULL if ns is None else _xcstr(ns) + if _delAttributeFromNsName(element._c_node, c_href, _xcstr(tag)): + raise KeyError, key + return 0 + +cdef int _delAttributeFromNsName(xmlNode* c_node, const_xmlChar* c_href, const_xmlChar* c_name) noexcept: + c_attr = tree.xmlHasNsProp(c_node, c_name, c_href) + if c_attr is NULL: + # XXX free namespace that is not in use..? + return -1 + tree.xmlRemoveProp(c_attr) + return 0 + +cdef list _collectAttributes(xmlNode* c_node, int collecttype): + """Collect all attributes of a node in a list. Depending on collecttype, + it collects either the name (1), the value (2) or the name-value tuples. + """ + cdef Py_ssize_t count + c_attr = c_node.properties + count = 0 + while c_attr is not NULL: + if c_attr.type == tree.XML_ATTRIBUTE_NODE: + count += 1 + c_attr = c_attr.next + + if not count: + return [] + + attributes = [None] * count + c_attr = c_node.properties + count = 0 + while c_attr is not NULL: + if c_attr.type == tree.XML_ATTRIBUTE_NODE: + if collecttype == 1: + item = _namespacedName(c_attr) + elif collecttype == 2: + item = _attributeValue(c_node, c_attr) + else: + item = (_namespacedName(c_attr), + _attributeValue(c_node, c_attr)) + attributes[count] = item + count += 1 + c_attr = c_attr.next + return attributes + +cdef object __RE_XML_ENCODING = re.compile( + r'^(<\?xml[^>]+)\s+encoding\s*=\s*["\'][^"\']*["\'](\s*\?>|)', re.U) + +cdef object __REPLACE_XML_ENCODING = __RE_XML_ENCODING.sub +cdef object __HAS_XML_ENCODING = __RE_XML_ENCODING.match + +cdef object _stripEncodingDeclaration(object xml_string): + # this is a hack to remove the XML encoding declaration from unicode + return __REPLACE_XML_ENCODING(r'\g<1>\g<2>', xml_string) + +cdef bint _hasEncodingDeclaration(object xml_string) except -1: + # check if a (unicode) string has an XML encoding declaration + return __HAS_XML_ENCODING(xml_string) is not None + +cdef inline bint _hasText(xmlNode* c_node) noexcept: + return c_node is not NULL and _textNodeOrSkip(c_node.children) is not NULL + +cdef inline bint _hasTail(xmlNode* c_node) noexcept: + return c_node is not NULL and _textNodeOrSkip(c_node.next) is not NULL + +cdef inline bint _hasNonWhitespaceTail(xmlNode* c_node) except -1: + return _hasNonWhitespaceText(c_node, tail=True) + +cdef bint _hasNonWhitespaceText(xmlNode* c_node, bint tail=False) except -1: + c_text_node = c_node and _textNodeOrSkip(c_node.next if tail else c_node.children) + if c_text_node is NULL: + return False + while c_text_node is not NULL: + if c_text_node.content[0] != c'\0' and not _collectText(c_text_node).isspace(): + return True + c_text_node = _textNodeOrSkip(c_text_node.next) + return False + +cdef unicode _collectText(xmlNode* c_node): + """Collect all text nodes and return them as a unicode string. + + Start collecting at c_node. + + If there was no text to collect, return None + """ + cdef Py_ssize_t scount + cdef xmlChar* c_text + cdef xmlNode* c_node_cur + # check for multiple text nodes + scount = 0 + c_text = NULL + c_node_cur = c_node = _textNodeOrSkip(c_node) + while c_node_cur is not NULL: + if c_node_cur.content[0] != c'\0': + c_text = c_node_cur.content + scount += 1 + c_node_cur = _textNodeOrSkip(c_node_cur.next) + + # handle two most common cases first + if c_text is NULL: + return '' if scount > 0 else None + if scount == 1: + return funicode(c_text) + + # the rest is not performance critical anymore + result = b'' + while c_node is not NULL: + result += c_node.content + c_node = _textNodeOrSkip(c_node.next) + return funicode(result) + +cdef void _removeText(xmlNode* c_node) noexcept: + """Remove all text nodes. + + Start removing at c_node. + """ + cdef xmlNode* c_next + c_node = _textNodeOrSkip(c_node) + while c_node is not NULL: + c_next = _textNodeOrSkip(c_node.next) + tree.xmlUnlinkNode(c_node) + tree.xmlFreeNode(c_node) + c_node = c_next + +cdef xmlNode* _createTextNode(xmlDoc* doc, value) except NULL: + cdef xmlNode* c_text_node + if isinstance(value, CDATA): + c_text_node = tree.xmlNewCDataBlock( + doc, _xcstr((value)._utf8_data), + python.PyBytes_GET_SIZE((value)._utf8_data)) + else: + text = _utf8(value) + c_text_node = tree.xmlNewDocText(doc, _xcstr(text)) + if not c_text_node: + raise MemoryError() + return c_text_node + +cdef int _setNodeText(xmlNode* c_node, value) except -1: + # remove all text nodes at the start first + _removeText(c_node.children) + if value is None: + return 0 + # now add new text node with value at start + c_text_node = _createTextNode(c_node.doc, value) + if c_node.children is NULL: + tree.xmlAddChild(c_node, c_text_node) + else: + tree.xmlAddPrevSibling(c_node.children, c_text_node) + return 0 + +cdef int _setTailText(xmlNode* c_node, value) except -1: + # remove all text nodes at the start first + _removeText(c_node.next) + if value is None: + return 0 + # now append new text node with value + c_text_node = _createTextNode(c_node.doc, value) + tree.xmlAddNextSibling(c_node, c_text_node) + return 0 + +cdef bytes _resolveQNameText(_Element element, value): + cdef xmlNs* c_ns + ns, tag = _getNsTag(value) + if ns is None: + return tag + else: + c_ns = element._doc._findOrBuildNodeNs( + element._c_node, _xcstr(ns), NULL, 0) + return python.PyBytes_FromFormat('%s:%s', c_ns.prefix, _cstr(tag)) + +cdef inline bint _hasChild(xmlNode* c_node) noexcept: + return c_node is not NULL and _findChildForwards(c_node, 0) is not NULL + +cdef inline Py_ssize_t _countElements(xmlNode* c_node) noexcept: + "Counts the elements within the following siblings and the node itself." + cdef Py_ssize_t count + count = 0 + while c_node is not NULL: + if _isElement(c_node): + count += 1 + c_node = c_node.next + return count + +cdef int _findChildSlice( + slice sliceobject, xmlNode* c_parent, + xmlNode** c_start_node, Py_ssize_t* c_step, Py_ssize_t* c_length) except -1: + """Resolve a children slice. + + Returns the start node, step size and the slice length in the + pointer arguments. + """ + cdef Py_ssize_t start = 0, stop = 0, childcount + childcount = _countElements(c_parent.children) + if childcount == 0: + c_start_node[0] = NULL + c_length[0] = 0 + if sliceobject.step is None: + c_step[0] = 1 + else: + python._PyEval_SliceIndex(sliceobject.step, c_step) + return 0 + python.PySlice_GetIndicesEx( + sliceobject, childcount, &start, &stop, c_step, c_length) + if start > childcount // 2: + c_start_node[0] = _findChildBackwards(c_parent, childcount - start - 1) + else: + c_start_node[0] = _findChild(c_parent, start) + return 0 + +cdef bint _isFullSlice(slice sliceobject) except -1: + """Conservative guess if this slice is a full slice as in ``s[:]``. + """ + cdef Py_ssize_t step = 0 + if sliceobject is None: + return 0 + if sliceobject.start is None and \ + sliceobject.stop is None: + if sliceobject.step is None: + return 1 + python._PyEval_SliceIndex(sliceobject.step, &step) + if step == 1: + return 1 + return 0 + return 0 + +cdef _collectChildren(_Element element): + cdef xmlNode* c_node + cdef list result = [] + c_node = element._c_node.children + if c_node is not NULL: + if not _isElement(c_node): + c_node = _nextElement(c_node) + while c_node is not NULL: + result.append(_elementFactory(element._doc, c_node)) + c_node = _nextElement(c_node) + return result + +cdef inline xmlNode* _findChild(xmlNode* c_node, Py_ssize_t index) noexcept: + if index < 0: + return _findChildBackwards(c_node, -index - 1) + else: + return _findChildForwards(c_node, index) + +cdef inline xmlNode* _findChildForwards(xmlNode* c_node, Py_ssize_t index) noexcept: + """Return child element of c_node with index, or return NULL if not found. + """ + cdef xmlNode* c_child + cdef Py_ssize_t c + c_child = c_node.children + c = 0 + while c_child is not NULL: + if _isElement(c_child): + if c == index: + return c_child + c += 1 + c_child = c_child.next + return NULL + +cdef inline xmlNode* _findChildBackwards(xmlNode* c_node, Py_ssize_t index) noexcept: + """Return child element of c_node with index, or return NULL if not found. + Search from the end. + """ + cdef xmlNode* c_child + cdef Py_ssize_t c + c_child = c_node.last + c = 0 + while c_child is not NULL: + if _isElement(c_child): + if c == index: + return c_child + c += 1 + c_child = c_child.prev + return NULL + +cdef inline xmlNode* _textNodeOrSkip(xmlNode* c_node) noexcept nogil: + """Return the node if it's a text node. Skip over ignorable nodes in a + series of text nodes. Return NULL if a non-ignorable node is found. + + This is used to skip over XInclude nodes when collecting adjacent text + nodes. + """ + while c_node is not NULL: + if c_node.type == tree.XML_TEXT_NODE or \ + c_node.type == tree.XML_CDATA_SECTION_NODE: + return c_node + elif c_node.type == tree.XML_XINCLUDE_START or \ + c_node.type == tree.XML_XINCLUDE_END: + c_node = c_node.next + else: + return NULL + return NULL + +cdef inline xmlNode* _nextElement(xmlNode* c_node) noexcept: + """Given a node, find the next sibling that is an element. + """ + if c_node is NULL: + return NULL + c_node = c_node.next + while c_node is not NULL: + if _isElement(c_node): + return c_node + c_node = c_node.next + return NULL + +cdef inline xmlNode* _previousElement(xmlNode* c_node) noexcept: + """Given a node, find the next sibling that is an element. + """ + if c_node is NULL: + return NULL + c_node = c_node.prev + while c_node is not NULL: + if _isElement(c_node): + return c_node + c_node = c_node.prev + return NULL + +cdef inline xmlNode* _parentElement(xmlNode* c_node) noexcept: + "Given a node, find the parent element." + if c_node is NULL or not _isElement(c_node): + return NULL + c_node = c_node.parent + if c_node is NULL or not _isElement(c_node): + return NULL + return c_node + +cdef inline bint _tagMatches(xmlNode* c_node, const_xmlChar* c_href, const_xmlChar* c_name) noexcept: + """Tests if the node matches namespace URI and tag name. + + A node matches if it matches both c_href and c_name. + + A node matches c_href if any of the following is true: + * c_href is NULL + * its namespace is NULL and c_href is the empty string + * its namespace string equals the c_href string + + A node matches c_name if any of the following is true: + * c_name is NULL + * its name string equals the c_name string + """ + if c_node is NULL: + return 0 + if c_node.type != tree.XML_ELEMENT_NODE: + # not an element, only succeed if we match everything + return c_name is NULL and c_href is NULL + if c_name is NULL: + if c_href is NULL: + # always match + return 1 + else: + c_node_href = _getNs(c_node) + if c_node_href is NULL: + return c_href[0] == c'\0' + else: + return tree.xmlStrcmp(c_node_href, c_href) == 0 + elif c_href is NULL: + if _getNs(c_node) is not NULL: + return 0 + return c_node.name == c_name or tree.xmlStrcmp(c_node.name, c_name) == 0 + elif c_node.name == c_name or tree.xmlStrcmp(c_node.name, c_name) == 0: + c_node_href = _getNs(c_node) + if c_node_href is NULL: + return c_href[0] == c'\0' + else: + return tree.xmlStrcmp(c_node_href, c_href) == 0 + else: + return 0 + +cdef inline bint _tagMatchesExactly(xmlNode* c_node, qname* c_qname) noexcept: + """Tests if the node matches namespace URI and tag name. + + This differs from _tagMatches() in that it does not consider a + NULL value in qname.href a wildcard, and that it expects the c_name + to be taken from the doc dict, i.e. it only compares the names by + address. + + A node matches if it matches both href and c_name of the qname. + + A node matches c_href if any of the following is true: + * its namespace is NULL and c_href is the empty string + * its namespace string equals the c_href string + + A node matches c_name if any of the following is true: + * c_name is NULL + * its name string points to the same address (!) as c_name + """ + return _nsTagMatchesExactly(_getNs(c_node), c_node.name, c_qname) + +cdef inline bint _nsTagMatchesExactly(const_xmlChar* c_node_href, + const_xmlChar* c_node_name, + qname* c_qname) noexcept: + """Tests if name and namespace URI match those of c_qname. + + This differs from _tagMatches() in that it does not consider a + NULL value in qname.href a wildcard, and that it expects the c_name + to be taken from the doc dict, i.e. it only compares the names by + address. + + A node matches if it matches both href and c_name of the qname. + + A node matches c_href if any of the following is true: + * its namespace is NULL and c_href is the empty string + * its namespace string equals the c_href string + + A node matches c_name if any of the following is true: + * c_name is NULL + * its name string points to the same address (!) as c_name + """ + cdef char* c_href + if c_qname.c_name is not NULL and c_qname.c_name is not c_node_name: + return 0 + if c_qname.href is NULL: + return 1 + c_href = python.__cstr(c_qname.href) + if c_href[0] == b'\0': + return c_node_href is NULL or c_node_href[0] == b'\0' + elif c_node_href is NULL: + return 0 + else: + return tree.xmlStrcmp(c_href, c_node_href) == 0 + +cdef Py_ssize_t _mapTagsToQnameMatchArray(xmlDoc* c_doc, list ns_tags, + qname* c_ns_tags, bint force_into_dict) except -1: + """Map a sequence of (name, namespace) pairs to a qname array for efficient + matching with _tagMatchesExactly() above. + + Note that each qname struct in the array owns its href byte string object + if it is not NULL. + """ + cdef Py_ssize_t count = 0, i, c_tag_len + cdef bytes ns, tag + cdef const_xmlChar* c_tag + + for ns, tag in ns_tags: + if tag is None: + c_tag = NULL + else: + c_tag_len = len(tag) + if c_tag_len > limits.INT_MAX: + # too long, not in the dict => not in the document + continue + elif force_into_dict: + c_tag = tree.xmlDictLookup(c_doc.dict, _xcstr(tag), c_tag_len) + if c_tag is NULL: + # clean up before raising the error + for i in xrange(count): + cpython.ref.Py_XDECREF(c_ns_tags[i].href) + raise MemoryError() + else: + c_tag = tree.xmlDictExists(c_doc.dict, _xcstr(tag), c_tag_len) + if c_tag is NULL: + # not in the dict => not in the document + continue + + c_ns_tags[count].c_name = c_tag + if ns is None: + c_ns_tags[count].href = NULL + else: + cpython.ref.Py_INCREF(ns) # keep an owned reference! + c_ns_tags[count].href = ns + count += 1 + return count + +cdef int _removeNode(_Document doc, xmlNode* c_node) except -1: + """Unlink and free a node and subnodes if possible. Otherwise, make sure + it's self-contained. + """ + cdef xmlNode* c_next + c_next = c_node.next + tree.xmlUnlinkNode(c_node) + _moveTail(c_next, c_node) + if not attemptDeallocation(c_node): + # make namespaces absolute + moveNodeToDocument(doc, c_node.doc, c_node) + return 0 + +cdef int _removeSiblings(xmlNode* c_element, tree.xmlElementType node_type, bint with_tail) except -1: + cdef xmlNode* c_node + cdef xmlNode* c_next + c_node = c_element.next + while c_node is not NULL: + c_next = _nextElement(c_node) + if c_node.type == node_type: + if with_tail: + _removeText(c_node.next) + tree.xmlUnlinkNode(c_node) + attemptDeallocation(c_node) + c_node = c_next + c_node = c_element.prev + while c_node is not NULL: + c_next = _previousElement(c_node) + if c_node.type == node_type: + if with_tail: + _removeText(c_node.next) + tree.xmlUnlinkNode(c_node) + attemptDeallocation(c_node) + c_node = c_next + return 0 + +cdef void _moveTail(xmlNode* c_tail, xmlNode* c_target) noexcept: + cdef xmlNode* c_next + # tail support: look for any text nodes trailing this node and + # move them too + c_tail = _textNodeOrSkip(c_tail) + while c_tail is not NULL: + c_next = _textNodeOrSkip(c_tail.next) + c_target = tree.xmlAddNextSibling(c_target, c_tail) + c_tail = c_next + +cdef int _copyTail(xmlNode* c_tail, xmlNode* c_target) except -1: + cdef xmlNode* c_new_tail + # tail copying support: look for any text nodes trailing this node and + # copy it to the target node + c_tail = _textNodeOrSkip(c_tail) + while c_tail is not NULL: + if c_target.doc is not c_tail.doc: + c_new_tail = tree.xmlDocCopyNode(c_tail, c_target.doc, 0) + else: + c_new_tail = tree.xmlCopyNode(c_tail, 0) + if c_new_tail is NULL: + raise MemoryError() + c_target = tree.xmlAddNextSibling(c_target, c_new_tail) + c_tail = _textNodeOrSkip(c_tail.next) + return 0 + +cdef int _copyNonElementSiblings(xmlNode* c_node, xmlNode* c_target) except -1: + cdef xmlNode* c_copy + cdef xmlNode* c_sibling = c_node + while c_sibling.prev != NULL and \ + (c_sibling.prev.type == tree.XML_PI_NODE or + c_sibling.prev.type == tree.XML_COMMENT_NODE or + c_sibling.prev.type == tree.XML_DTD_NODE): + c_sibling = c_sibling.prev + while c_sibling != c_node: + if c_sibling.type == tree.XML_DTD_NODE: + c_copy = _copyDtd(c_sibling) + if c_sibling == c_node.doc.intSubset: + c_target.doc.intSubset = c_copy + else: # c_sibling == c_node.doc.extSubset + c_target.doc.extSubset = c_copy + else: + c_copy = tree.xmlDocCopyNode(c_sibling, c_target.doc, 1) + if c_copy is NULL: + raise MemoryError() + tree.xmlAddPrevSibling(c_target, c_copy) + c_sibling = c_sibling.next + while c_sibling.next != NULL and \ + (c_sibling.next.type == tree.XML_PI_NODE or + c_sibling.next.type == tree.XML_COMMENT_NODE): + c_sibling = c_sibling.next + c_copy = tree.xmlDocCopyNode(c_sibling, c_target.doc, 1) + if c_copy is NULL: + raise MemoryError() + tree.xmlAddNextSibling(c_target, c_copy) + +cdef int _deleteSlice(_Document doc, xmlNode* c_node, + Py_ssize_t count, Py_ssize_t step) except -1: + """Delete slice, ``count`` items starting with ``c_node`` with a step + width of ``step``. + """ + cdef xmlNode* c_next + cdef Py_ssize_t c, i + cdef _node_to_node_function next_element + if c_node is NULL: + return 0 + if step > 0: + next_element = _nextElement + else: + step = -step + next_element = _previousElement + # now start deleting nodes + c = 0 + c_next = c_node + while c_node is not NULL and c < count: + for i in range(step): + c_next = next_element(c_next) + if c_next is NULL: + break + _removeNode(doc, c_node) + c += 1 + c_node = c_next + return 0 + +cdef int _replaceSlice(_Element parent, xmlNode* c_node, + Py_ssize_t slicelength, Py_ssize_t step, + bint left_to_right, elements) except -1: + """Replace the slice of ``count`` elements starting at ``c_node`` with + positive step width ``step`` by the Elements in ``elements``. The + direction is given by the boolean argument ``left_to_right``. + + ``c_node`` may be NULL to indicate the end of the children list. + """ + cdef xmlNode* c_orig_neighbour + cdef xmlNode* c_next + cdef xmlDoc* c_source_doc + cdef _Element element + cdef Py_ssize_t seqlength, i, c + cdef _node_to_node_function next_element + assert step > 0 + if left_to_right: + next_element = _nextElement + else: + next_element = _previousElement + + if not isinstance(elements, (list, tuple)): + elements = list(elements) + + if step != 1 or not left_to_right: + # *replacing* children stepwise with list => check size! + seqlength = len(elements) + if seqlength != slicelength: + raise ValueError, f"attempt to assign sequence of size {seqlength} " \ + f"to extended slice of size {slicelength}" + + if c_node is NULL: + # no children yet => add all elements straight away + if left_to_right: + for element in elements: + assert element is not None, "Node must not be None" + _appendChild(parent, element) + else: + for element in elements: + assert element is not None, "Node must not be None" + _prependChild(parent, element) + return 0 + + # remove the elements first as some might be re-added + if left_to_right: + # L->R, remember left neighbour + c_orig_neighbour = _previousElement(c_node) + else: + # R->L, remember right neighbour + c_orig_neighbour = _nextElement(c_node) + + # We remove the original slice elements one by one. Since we hold + # a Python reference to all elements that we will insert, it is + # safe to let _removeNode() try (and fail) to free them even if + # the element itself or one of its descendents will be reinserted. + c = 0 + c_next = c_node + while c_node is not NULL and c < slicelength: + for i in range(step): + c_next = next_element(c_next) + if c_next is NULL: + break + _removeNode(parent._doc, c_node) + c += 1 + c_node = c_next + + # make sure each element is inserted only once + elements = iter(elements) + + # find the first node right of the new insertion point + if left_to_right: + if c_orig_neighbour is not NULL: + c_node = next_element(c_orig_neighbour) + else: + # before the first element + c_node = _findChildForwards(parent._c_node, 0) + elif c_orig_neighbour is NULL: + # at the end, but reversed stepping + # append one element and go to the next insertion point + for element in elements: + assert element is not None, "Node must not be None" + _appendChild(parent, element) + c_node = element._c_node + if slicelength > 0: + slicelength -= 1 + for i in range(1, step): + c_node = next_element(c_node) + if c_node is NULL: + break + break + else: + c_node = c_orig_neighbour + + if left_to_right: + # adjust step size after removing slice as we are not stepping + # over the newly inserted elements + step -= 1 + + # now insert elements where we removed them + if c_node is not NULL: + for element in elements: + assert element is not None, "Node must not be None" + _assertValidNode(element) + # move element and tail over + c_source_doc = element._c_node.doc + c_next = element._c_node.next + tree.xmlAddPrevSibling(c_node, element._c_node) + _moveTail(c_next, element._c_node) + + # integrate element into new document + moveNodeToDocument(parent._doc, c_source_doc, element._c_node) + + # stop at the end of the slice + if slicelength > 0: + slicelength -= 1 + for i in range(step): + c_node = next_element(c_node) + if c_node is NULL: + break + if c_node is NULL: + break + else: + # everything inserted + return 0 + + # append the remaining elements at the respective end + if left_to_right: + for element in elements: + assert element is not None, "Node must not be None" + _assertValidNode(element) + _appendChild(parent, element) + else: + for element in elements: + assert element is not None, "Node must not be None" + _assertValidNode(element) + _prependChild(parent, element) + + return 0 + + +cdef int _linkChild(xmlNode* c_parent, xmlNode* c_node) except -1: + """Adaptation of 'xmlAddChild()' that deep-fix the document links iteratively. + """ + assert _isElement(c_node) + c_node.parent = c_parent + if c_parent.children is NULL: + c_parent.children = c_parent.last = c_node + else: + c_node.prev = c_parent.last + c_parent.last.next = c_node + c_parent.last = c_node + + _setTreeDoc(c_node, c_parent.doc) + return 0 + + +cdef int _appendChild(_Element parent, _Element child) except -1: + """Append a new child to a parent element. + """ + c_node = child._c_node + c_source_doc = c_node.doc + # prevent cycles + if _isAncestorOrSame(c_node, parent._c_node): + raise ValueError("cannot append parent to itself") + # store possible text node + c_next = c_node.next + # move node itself + tree.xmlUnlinkNode(c_node) + # do not call xmlAddChild() here since it would deep-traverse the tree + _linkChild(parent._c_node, c_node) + _moveTail(c_next, c_node) + # uh oh, elements may be pointing to different doc when + # parent element has moved; change them too.. + moveNodeToDocument(parent._doc, c_source_doc, c_node) + return 0 + +cdef int _prependChild(_Element parent, _Element child) except -1: + """Prepend a new child to a parent element. + """ + c_node = child._c_node + c_source_doc = c_node.doc + # prevent cycles + if _isAncestorOrSame(c_node, parent._c_node): + raise ValueError("cannot append parent to itself") + # store possible text node + c_next = c_node.next + # move node itself + c_child = _findChildForwards(parent._c_node, 0) + if c_child is NULL: + tree.xmlUnlinkNode(c_node) + # do not call xmlAddChild() here since it would deep-traverse the tree + _linkChild(parent._c_node, c_node) + else: + tree.xmlAddPrevSibling(c_child, c_node) + _moveTail(c_next, c_node) + # uh oh, elements may be pointing to different doc when + # parent element has moved; change them too.. + moveNodeToDocument(parent._doc, c_source_doc, c_node) + return 0 + +cdef int _appendSibling(_Element element, _Element sibling) except -1: + """Add a new sibling behind an element. + """ + return _addSibling(element, sibling, as_next=True) + +cdef int _prependSibling(_Element element, _Element sibling) except -1: + """Add a new sibling before an element. + """ + return _addSibling(element, sibling, as_next=False) + +cdef int _addSibling(_Element element, _Element sibling, bint as_next) except -1: + c_node = sibling._c_node + c_source_doc = c_node.doc + # prevent cycles + if _isAncestorOrSame(c_node, element._c_node): + if element._c_node is c_node: + return 0 # nothing to do + raise ValueError("cannot add ancestor as sibling, please break cycle first") + # store possible text node + c_next = c_node.next + # move node itself + if as_next: + # must insert after any tail text + c_next_node = _nextElement(element._c_node) + if c_next_node is NULL: + c_next_node = element._c_node + while c_next_node.next: + c_next_node = c_next_node.next + tree.xmlAddNextSibling(c_next_node, c_node) + else: + tree.xmlAddPrevSibling(c_next_node, c_node) + else: + tree.xmlAddPrevSibling(element._c_node, c_node) + _moveTail(c_next, c_node) + # uh oh, elements may be pointing to different doc when + # parent element has moved; change them too.. + moveNodeToDocument(element._doc, c_source_doc, c_node) + return 0 + +cdef inline bint isutf8(const_xmlChar* s) noexcept: + cdef xmlChar c = s[0] + while c != c'\0': + if c & 0x80: + return True + s += 1 + c = s[0] + return False + +cdef bint isutf8l(const_xmlChar* s, size_t length) noexcept: + """ + Search for non-ASCII characters in the string, knowing its length in advance. + """ + cdef unsigned int i + cdef unsigned long non_ascii_mask + cdef const unsigned long *lptr = s + + cdef const unsigned long *end = lptr + length // sizeof(unsigned long) + if length >= sizeof(non_ascii_mask): + # Build constant 0x80808080... mask (and let the C compiler fold it). + non_ascii_mask = 0 + for i in range(sizeof(non_ascii_mask) // 2): + non_ascii_mask = (non_ascii_mask << 16) | 0x8080 + + # Advance to long-aligned character before we start reading longs. + while (s) % sizeof(unsigned long) and s < end: + if s[0] & 0x80: + return True + s += 1 + + # Read one long at a time + lptr = s + while lptr < end: + if lptr[0] & non_ascii_mask: + return True + lptr += 1 + s = lptr + + while s < (end + length % sizeof(unsigned long)): + if s[0] & 0x80: + return True + s += 1 + + return False + +cdef int _is_valid_xml_ascii(bytes pystring) except -1: + """Check if a string is XML ascii content.""" + cdef signed char ch + # When ch is a *signed* char, non-ascii characters are negative integers + # and xmlIsChar_ch does not accept them. + for ch in pystring: + if not tree.xmlIsChar_ch(ch): + return 0 + return 1 + +cdef bint _is_valid_xml_utf8(bytes pystring) except -1: + """Check if a string is like valid UTF-8 XML content.""" + cdef const_xmlChar* s = _xcstr(pystring) + cdef const_xmlChar* c_end = s + len(pystring) + cdef unsigned long next3 = 0 + if s < c_end - 2: + next3 = (s[0] << 8) | (s[1]) + + while s < c_end - 2: + next3 = 0x00ffffff & ((next3 << 8) | s[2]) + if s[0] & 0x80: + # 0xefbfbe and 0xefbfbf are utf-8 encodings of + # forbidden characters \ufffe and \uffff + if next3 == 0x00efbfbe or next3 == 0x00efbfbf: + return 0 + # 0xeda080 and 0xedbfbf are utf-8 encodings of + # \ud800 and \udfff. Anything between them (inclusive) + # is forbidden, because they are surrogate blocks in utf-16. + if 0x00eda080 <= next3 <= 0x00edbfbf: + return 0 + elif not tree.xmlIsChar_ch(s[0]): + return 0 # invalid ascii char + s += 1 + + while s < c_end: + if not s[0] & 0x80 and not tree.xmlIsChar_ch(s[0]): + return 0 # invalid ascii char + s += 1 + + return 1 + +cdef inline unicode funicodeOrNone(const_xmlChar* s): + return funicode(s) if s is not NULL else None + +cdef inline unicode funicodeOrEmpty(const_xmlChar* s): + return funicode(s) if s is not NULL else '' + +cdef unicode funicode(const_xmlChar* s): + return s.decode('UTF-8') + +cdef bytes _utf8(object s): + """Test if a string is valid user input and encode it to UTF-8. + Reject all bytes/unicode input that contains non-XML characters. + Reject all bytes input that contains non-ASCII characters. + """ + cdef int valid + cdef bytes utf8_string + if isinstance(s, unicode): + utf8_string = (s).encode('utf8') + valid = _is_valid_xml_utf8(utf8_string) + elif isinstance(s, (bytes, bytearray)): + utf8_string = s if type(s) is bytes else bytes(s) + valid = _is_valid_xml_ascii(utf8_string) + else: + raise TypeError("Argument must be bytes or unicode, got '%.200s'" % type(s).__name__) + if not valid: + raise ValueError( + "All strings must be XML compatible: Unicode or ASCII, no NULL bytes or control characters") + return utf8_string + + +cdef bytes _utf8orNone(object s): + return _utf8(s) if s is not None else None + + +cdef enum: + NO_FILE_PATH = 0 + ABS_UNIX_FILE_PATH = 1 + ABS_WIN_FILE_PATH = 2 + REL_FILE_PATH = 3 + + +cdef bint _isFilePath(const_xmlChar* c_path) noexcept: + "simple heuristic to see if a path is a filename" + cdef xmlChar c + # test if it looks like an absolute Unix path or a Windows network path + if c_path[0] == c'/': + return ABS_UNIX_FILE_PATH + + # test if it looks like an absolute Windows path or URL + if c'a' <= c_path[0] <= c'z' or c'A' <= c_path[0] <= c'Z': + c_path += 1 + if c_path[0] == c':' and c_path[1] in b'\0\\': + return ABS_WIN_FILE_PATH # C: or C:\... + + # test if it looks like a URL with scheme:// + while c'a' <= c_path[0] <= c'z' or c'A' <= c_path[0] <= c'Z': + c_path += 1 + if c_path[0] == c':' and c_path[1] == c'/' and c_path[2] == c'/': + return NO_FILE_PATH + + # assume it's a relative path + return REL_FILE_PATH + + +cdef object _getFSPathOrObject(object obj): + """ + Get the __fspath__ attribute of an object if it exists. + Otherwise, the original object is returned. + """ + if _isString(obj): + return obj + try: + return python.PyOS_FSPath(obj) + except TypeError: + return obj + + +cdef object _encodeFilename(object filename): + """Make sure a filename is 8-bit encoded (or None). + """ + if filename is None: + return None + elif isinstance(filename, bytes): + return filename + elif isinstance(filename, unicode): + filename8 = (filename).encode('utf8') + if _isFilePath(filename8): + try: + return python.PyUnicode_AsEncodedString( + filename, _C_FILENAME_ENCODING, NULL) + except UnicodeEncodeError: + pass + return filename8 + else: + raise TypeError("Argument must be string or unicode.") + +cdef object _decodeFilename(const_xmlChar* c_path): + """Make the filename a unicode string if we are in Py3. + """ + return _decodeFilenameWithLength(c_path, tree.xmlStrlen(c_path)) + +cdef object _decodeFilenameWithLength(const_xmlChar* c_path, size_t c_len): + """Make the filename a unicode string if we are in Py3. + """ + if _isFilePath(c_path): + try: + return python.PyUnicode_Decode( + c_path, c_len, _C_FILENAME_ENCODING, NULL) + except UnicodeDecodeError: + pass + try: + return (c_path)[:c_len].decode('UTF-8') + except UnicodeDecodeError: + # this is a stupid fallback, but it might still work... + return (c_path)[:c_len].decode('latin-1', 'replace') + +cdef object _encodeFilenameUTF8(object filename): + """Recode filename as UTF-8. Tries ASCII, local filesystem encoding and + UTF-8 as source encoding. + """ + cdef char* c_filename + if filename is None: + return None + elif isinstance(filename, bytes): + if not isutf8l(filename, len(filename)): + # plain ASCII! + return filename + c_filename = _cstr(filename) + try: + # try to decode with default encoding + filename = python.PyUnicode_Decode( + c_filename, len(filename), + _C_FILENAME_ENCODING, NULL) + except UnicodeDecodeError as decode_exc: + try: + # try if it's proper UTF-8 + (filename).decode('utf8') + return filename + except UnicodeDecodeError: + raise decode_exc # otherwise re-raise original exception + if isinstance(filename, unicode): + return (filename).encode('utf8') + else: + raise TypeError("Argument must be string or unicode.") + +cdef tuple _getNsTag(tag): + """Given a tag, find namespace URI and tag name. + Return None for NS uri if no namespace URI provided. + """ + return __getNsTag(tag, 0) + +cdef tuple _getNsTagWithEmptyNs(tag): + """Given a tag, find namespace URI and tag name. Return None for NS uri + if no namespace URI provided, or the empty string if namespace + part is '{}'. + """ + return __getNsTag(tag, 1) + +cdef tuple __getNsTag(tag, bint empty_ns): + cdef char* c_tag + cdef char* c_ns_end + cdef Py_ssize_t taglen + cdef Py_ssize_t nslen + cdef bytes ns = None + # _isString() is much faster than isinstance() + if not _isString(tag) and isinstance(tag, QName): + tag = (tag).text + tag = _utf8(tag) + c_tag = _cstr(tag) + if c_tag[0] == c'{': + c_tag += 1 + c_ns_end = cstring_h.strchr(c_tag, c'}') + if c_ns_end is NULL: + raise ValueError, "Invalid tag name" + nslen = c_ns_end - c_tag + taglen = python.PyBytes_GET_SIZE(tag) - nslen - 2 + if taglen == 0: + raise ValueError, "Empty tag name" + if nslen > 0: + ns = c_tag[:nslen] + elif empty_ns: + ns = b'' + tag = c_ns_end[1:taglen+1] + elif python.PyBytes_GET_SIZE(tag) == 0: + raise ValueError, "Empty tag name" + return ns, tag + +cdef inline int _pyXmlNameIsValid(name_utf8): + return _xmlNameIsValid(_xcstr(name_utf8)) and b':' not in name_utf8 + +cdef inline int _pyHtmlNameIsValid(name_utf8): + return _htmlNameIsValid(_xcstr(name_utf8)) + +cdef inline int _xmlNameIsValid(const_xmlChar* c_name) noexcept: + return tree.xmlValidateNameValue(c_name) + +cdef int _htmlNameIsValid(const_xmlChar* c_name) noexcept: + if c_name is NULL or c_name[0] == c'\0': + return 0 + while c_name[0] != c'\0': + if c_name[0] in b'&<>/"\'\t\n\x0B\x0C\r ': + return 0 + c_name += 1 + return 1 + +cdef bint _characterReferenceIsValid(const_xmlChar* c_name) noexcept: + cdef bint is_hex + if c_name[0] == c'x': + c_name += 1 + is_hex = 1 + else: + is_hex = 0 + if c_name[0] == c'\0': + return 0 + while c_name[0] != c'\0': + if c_name[0] < c'0' or c_name[0] > c'9': + if not is_hex: + return 0 + if not (c'a' <= c_name[0] <= c'f'): + if not (c'A' <= c_name[0] <= c'F'): + return 0 + c_name += 1 + return 1 + +cdef int _tagValidOrRaise(tag_utf) except -1: + if not _pyXmlNameIsValid(tag_utf): + raise ValueError(f"Invalid tag name {(tag_utf).decode('utf8')!r}") + return 0 + +cdef int _htmlTagValidOrRaise(tag_utf) except -1: + if not _pyHtmlNameIsValid(tag_utf): + raise ValueError(f"Invalid HTML tag name {(tag_utf).decode('utf8')!r}") + return 0 + +cdef int _attributeValidOrRaise(name_utf) except -1: + if not _pyXmlNameIsValid(name_utf): + raise ValueError(f"Invalid attribute name {(name_utf).decode('utf8')!r}") + return 0 + +cdef int _prefixValidOrRaise(tag_utf) except -1: + if not _pyXmlNameIsValid(tag_utf): + raise ValueError(f"Invalid namespace prefix {(tag_utf).decode('utf8')!r}") + return 0 + +cdef int _uriValidOrRaise(uri_utf) except -1: + cdef uri.xmlURI* c_uri = uri.xmlParseURI(_cstr(uri_utf)) + if c_uri is NULL: + raise ValueError(f"Invalid namespace URI {(uri_utf).decode('utf8')!r}") + uri.xmlFreeURI(c_uri) + return 0 + +cdef inline unicode _namespacedName(xmlNode* c_node): + return _namespacedNameFromNsName(_getNs(c_node), c_node.name) + + +cdef unicode _namespacedNameFromNsName(const_xmlChar* c_href, const_xmlChar* c_name): + name = funicode(c_name) + if c_href is NULL: + return name + href = funicode(c_href) + return f"{{{href}}}{name}" + + +cdef _getFilenameForFile(source): + """Given a Python File or Gzip object, give filename back. + + Returns None if not a file object. + """ + # urllib2 provides a geturl() method + try: + return source.geturl() + except: + pass + # file instances have a name attribute + try: + filename = source.name + if _isString(filename): + return os_path_abspath(filename) + except: + pass + # gzip file instances have a filename attribute (before Py3k) + try: + filename = source.filename + if _isString(filename): + return os_path_abspath(filename) + except: + pass + # can't determine filename + return None diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/builder.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..f5831fb34b7911eb1f420e41ea32484eebff5f85 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/builder.py @@ -0,0 +1,243 @@ +# cython: language_level=2 + +# +# Element generator factory by Fredrik Lundh. +# +# Source: +# http://online.effbot.org/2006_11_01_archive.htm#et-builder +# http://effbot.python-hosting.com/file/stuff/sandbox/elementlib/builder.py +# +# -------------------------------------------------------------------- +# The ElementTree toolkit is +# +# Copyright (c) 1999-2004 by Fredrik Lundh +# +# By obtaining, using, and/or copying this software and/or its +# associated documentation, you agree that you have read, understood, +# and will comply with the following terms and conditions: +# +# Permission to use, copy, modify, and distribute this software and +# its associated documentation for any purpose and without fee is +# hereby granted, provided that the above copyright notice appears in +# all copies, and that both that copyright notice and this permission +# notice appear in supporting documentation, and that the name of +# Secret Labs AB or the author not be used in advertising or publicity +# pertaining to distribution of the software without specific, written +# prior permission. +# +# SECRET LABS AB AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD +# TO THIS SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANT- +# ABILITY AND FITNESS. IN NO EVENT SHALL SECRET LABS AB OR THE AUTHOR +# BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY +# DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, +# WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS +# ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE +# OF THIS SOFTWARE. +# -------------------------------------------------------------------- + +""" +The ``E`` Element factory for generating XML documents. +""" + + +import lxml.etree as ET +_QName = ET.QName + +from functools import partial + +try: + from types import GenericAlias as _GenericAlias +except ImportError: + # Python 3.8 - we only need this as return value from "__class_getitem__" + def _GenericAlias(cls, item): + return f"{cls.__name__}[{item.__name__}]" + +try: + basestring +except NameError: + basestring = str + +try: + unicode +except NameError: + unicode = str + + +class ElementMaker: + """Element generator factory. + + Unlike the ordinary Element factory, the E factory allows you to pass in + more than just a tag and some optional attributes; you can also pass in + text and other elements. The text is added as either text or tail + attributes, and elements are inserted at the right spot. Some small + examples:: + + >>> from lxml import etree as ET + >>> from lxml.builder import E + + >>> ET.tostring(E("tag")) + '' + >>> ET.tostring(E("tag", "text")) + 'text' + >>> ET.tostring(E("tag", "text", key="value")) + 'text' + >>> ET.tostring(E("tag", E("subtag", "text"), "tail")) + 'texttail' + + For simple tags, the factory also allows you to write ``E.tag(...)`` instead + of ``E('tag', ...)``:: + + >>> ET.tostring(E.tag()) + '' + >>> ET.tostring(E.tag("text")) + 'text' + >>> ET.tostring(E.tag(E.subtag("text"), "tail")) + 'texttail' + + Here's a somewhat larger example; this shows how to generate HTML + documents, using a mix of prepared factory functions for inline elements, + nested ``E.tag`` calls, and embedded XHTML fragments:: + + # some common inline elements + A = E.a + I = E.i + B = E.b + + def CLASS(v): + # helper function, 'class' is a reserved word + return {'class': v} + + page = ( + E.html( + E.head( + E.title("This is a sample document") + ), + E.body( + E.h1("Hello!", CLASS("title")), + E.p("This is a paragraph with ", B("bold"), " text in it!"), + E.p("This is another paragraph, with a ", + A("link", href="http://www.python.org"), "."), + E.p("Here are some reserved characters: ."), + ET.XML("

And finally, here is an embedded XHTML fragment.

"), + ) + ) + ) + + print ET.tostring(page) + + Here's a prettyprinted version of the output from the above script:: + + + + This is a sample document + + +

Hello!

+

This is a paragraph with bold text in it!

+

This is another paragraph, with link.

+

Here are some reserved characters: <spam&egg>.

+

And finally, here is an embedded XHTML fragment.

+ + + + For namespace support, you can pass a namespace map (``nsmap``) + and/or a specific target ``namespace`` to the ElementMaker class:: + + >>> E = ElementMaker(namespace="http://my.ns/") + >>> print(ET.tostring( E.test )) + + + >>> E = ElementMaker(namespace="http://my.ns/", nsmap={'p':'http://my.ns/'}) + >>> print(ET.tostring( E.test )) + + """ + + def __init__(self, typemap=None, + namespace=None, nsmap=None, makeelement=None): + self._namespace = '{' + namespace + '}' if namespace is not None else None + self._nsmap = dict(nsmap) if nsmap else None + + assert makeelement is None or callable(makeelement) + self._makeelement = makeelement if makeelement is not None else ET.Element + + # initialize the default type map functions for this element factory + typemap = dict(typemap) if typemap else {} + + def add_text(elem, item): + try: + last_child = elem[-1] + except IndexError: + elem.text = (elem.text or "") + item + else: + last_child.tail = (last_child.tail or "") + item + + def add_cdata(elem, cdata): + if elem.text: + raise ValueError("Can't add a CDATA section. Element already has some text: %r" % elem.text) + elem.text = cdata + + if str not in typemap: + typemap[str] = add_text + if unicode not in typemap: + typemap[unicode] = add_text + if ET.CDATA not in typemap: + typemap[ET.CDATA] = add_cdata + + def add_dict(elem, item): + attrib = elem.attrib + for k, v in item.items(): + if isinstance(v, basestring): + attrib[k] = v + else: + attrib[k] = typemap[type(v)](None, v) + + if dict not in typemap: + typemap[dict] = add_dict + + self._typemap = typemap + + def __call__(self, tag, *children, **attrib): + typemap = self._typemap + + # We'll usually get a 'str', and the compiled type check is very fast. + if not isinstance(tag, str) and isinstance(tag, _QName): + # A QName is explicitly qualified, do not look at self._namespace. + tag = tag.text + elif self._namespace is not None and tag[0] != '{': + tag = self._namespace + tag + elem = self._makeelement(tag, nsmap=self._nsmap) + if attrib: + typemap[dict](elem, attrib) + + for item in children: + if callable(item): + item = item() + t = typemap.get(type(item)) + if t is None: + if ET.iselement(item): + elem.append(item) + continue + for basetype in type(item).__mro__: + # See if the typemap knows of any of this type's bases. + t = typemap.get(basetype) + if t is not None: + break + else: + raise TypeError("bad argument type: %s(%r)" % + (type(item).__name__, item)) + v = t(elem, item) + if v: + typemap.get(type(v))(elem, v) + + return elem + + def __getattr__(self, tag): + return partial(self, tag) + + # Allow subscripting ElementMaker in type annotions (PEP 560) + def __class_getitem__(cls, item): + return _GenericAlias(cls, item) + + +# create factory object +E = ElementMaker() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/classlookup.pxi b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/classlookup.pxi new file mode 100644 index 0000000000000000000000000000000000000000..92d1d47a58657a7741d20f48cfe3525a66dbc722 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/classlookup.pxi @@ -0,0 +1,580 @@ +# Configurable Element class lookup + +################################################################################ +# Custom Element classes + +cdef public class ElementBase(_Element) [ type LxmlElementBaseType, + object LxmlElementBase ]: + """ElementBase(*children, attrib=None, nsmap=None, **_extra) + + The public Element class. All custom Element classes must inherit + from this one. To create an Element, use the `Element()` factory. + + BIG FAT WARNING: Subclasses *must not* override __init__ or + __new__ as it is absolutely undefined when these objects will be + created or destroyed. All persistent state of Elements must be + stored in the underlying XML. If you really need to initialize + the object after creation, you can implement an ``_init(self)`` + method that will be called directly after object creation. + + Subclasses of this class can be instantiated to create a new + Element. By default, the tag name will be the class name and the + namespace will be empty. You can modify this with the following + class attributes: + + * TAG - the tag name, possibly containing a namespace in Clark + notation + + * NAMESPACE - the default namespace URI, unless provided as part + of the TAG attribute. + + * HTML - flag if the class is an HTML tag, as opposed to an XML + tag. This only applies to un-namespaced tags and defaults to + false (i.e. XML). + + * PARSER - the parser that provides the configuration for the + newly created document. Providing an HTML parser here will + default to creating an HTML element. + + In user code, the latter three are commonly inherited in class + hierarchies that implement a common namespace. + """ + def __init__(self, *children, attrib=None, nsmap=None, **_extra): + """ElementBase(*children, attrib=None, nsmap=None, **_extra) + """ + cdef bint is_html = 0 + cdef _BaseParser parser + cdef _Element last_child + # don't use normal attribute access as it might be overridden + _getattr = object.__getattribute__ + try: + namespace = _utf8(_getattr(self, 'NAMESPACE')) + except AttributeError: + namespace = None + try: + ns, tag = _getNsTag(_getattr(self, 'TAG')) + if ns is not None: + namespace = ns + except AttributeError: + tag = _utf8(_getattr(_getattr(self, '__class__'), '__name__')) + if b'.' in tag: + tag = tag.split(b'.')[-1] + try: + parser = _getattr(self, 'PARSER') + except AttributeError: + parser = None + for child in children: + if isinstance(child, _Element): + parser = (<_Element>child)._doc._parser + break + if isinstance(parser, HTMLParser): + is_html = 1 + if namespace is None: + try: + is_html = _getattr(self, 'HTML') + except AttributeError: + pass + _initNewElement(self, is_html, tag, namespace, parser, + attrib, nsmap, _extra) + last_child = None + for child in children: + if _isString(child): + if last_child is None: + _setNodeText(self._c_node, + (_collectText(self._c_node.children) or '') + child) + else: + _setTailText(last_child._c_node, + (_collectText(last_child._c_node.next) or '') + child) + elif isinstance(child, _Element): + last_child = child + _appendChild(self, last_child) + elif isinstance(child, type) and issubclass(child, ElementBase): + last_child = child() + _appendChild(self, last_child) + else: + raise TypeError, f"Invalid child type: {type(child)!r}" + +cdef class CommentBase(_Comment): + """All custom Comment classes must inherit from this one. + + To create an XML Comment instance, use the ``Comment()`` factory. + + Subclasses *must not* override __init__ or __new__ as it is + absolutely undefined when these objects will be created or + destroyed. All persistent state of Comments must be stored in the + underlying XML. If you really need to initialize the object after + creation, you can implement an ``_init(self)`` method that will be + called after object creation. + """ + def __init__(self, text): + # copied from Comment() factory + cdef _Document doc + cdef xmlDoc* c_doc + if text is None: + text = b'' + else: + text = _utf8(text) + c_doc = _newXMLDoc() + doc = _documentFactory(c_doc, None) + self._c_node = _createComment(c_doc, _xcstr(text)) + if self._c_node is NULL: + raise MemoryError() + tree.xmlAddChild(c_doc, self._c_node) + _registerProxy(self, doc, self._c_node) + self._init() + +cdef class PIBase(_ProcessingInstruction): + """All custom Processing Instruction classes must inherit from this one. + + To create an XML ProcessingInstruction instance, use the ``PI()`` + factory. + + Subclasses *must not* override __init__ or __new__ as it is + absolutely undefined when these objects will be created or + destroyed. All persistent state of PIs must be stored in the + underlying XML. If you really need to initialize the object after + creation, you can implement an ``_init(self)`` method that will be + called after object creation. + """ + def __init__(self, target, text=None): + # copied from PI() factory + cdef _Document doc + cdef xmlDoc* c_doc + target = _utf8(target) + if text is None: + text = b'' + else: + text = _utf8(text) + c_doc = _newXMLDoc() + doc = _documentFactory(c_doc, None) + self._c_node = _createPI(c_doc, _xcstr(target), _xcstr(text)) + if self._c_node is NULL: + raise MemoryError() + tree.xmlAddChild(c_doc, self._c_node) + _registerProxy(self, doc, self._c_node) + self._init() + +cdef class EntityBase(_Entity): + """All custom Entity classes must inherit from this one. + + To create an XML Entity instance, use the ``Entity()`` factory. + + Subclasses *must not* override __init__ or __new__ as it is + absolutely undefined when these objects will be created or + destroyed. All persistent state of Entities must be stored in the + underlying XML. If you really need to initialize the object after + creation, you can implement an ``_init(self)`` method that will be + called after object creation. + """ + def __init__(self, name): + cdef _Document doc + cdef xmlDoc* c_doc + name_utf = _utf8(name) + c_name = _xcstr(name_utf) + if c_name[0] == c'#': + if not _characterReferenceIsValid(c_name + 1): + raise ValueError, f"Invalid character reference: '{name}'" + elif not _xmlNameIsValid(c_name): + raise ValueError, f"Invalid entity reference: '{name}'" + c_doc = _newXMLDoc() + doc = _documentFactory(c_doc, None) + self._c_node = _createEntity(c_doc, c_name) + if self._c_node is NULL: + raise MemoryError() + tree.xmlAddChild(c_doc, self._c_node) + _registerProxy(self, doc, self._c_node) + self._init() + + +cdef int _validateNodeClass(xmlNode* c_node, cls) except -1: + if c_node.type == tree.XML_ELEMENT_NODE: + expected = ElementBase + elif c_node.type == tree.XML_COMMENT_NODE: + expected = CommentBase + elif c_node.type == tree.XML_ENTITY_REF_NODE: + expected = EntityBase + elif c_node.type == tree.XML_PI_NODE: + expected = PIBase + else: + assert False, f"Unknown node type: {c_node.type}" + + if not (isinstance(cls, type) and issubclass(cls, expected)): + raise TypeError( + f"result of class lookup must be subclass of {type(expected)}, got {type(cls)}") + return 0 + + +################################################################################ +# Element class lookup + +ctypedef public object (*_element_class_lookup_function)(object, _Document, xmlNode*) + +# class to store element class lookup functions +cdef public class ElementClassLookup [ type LxmlElementClassLookupType, + object LxmlElementClassLookup ]: + """ElementClassLookup(self) + Superclass of Element class lookups. + """ + cdef _element_class_lookup_function _lookup_function + + +cdef public class FallbackElementClassLookup(ElementClassLookup) \ + [ type LxmlFallbackElementClassLookupType, + object LxmlFallbackElementClassLookup ]: + """FallbackElementClassLookup(self, fallback=None) + + Superclass of Element class lookups with additional fallback. + """ + cdef readonly ElementClassLookup fallback + cdef _element_class_lookup_function _fallback_function + def __cinit__(self): + # fall back to default lookup + self._fallback_function = _lookupDefaultElementClass + + def __init__(self, ElementClassLookup fallback=None): + if fallback is not None: + self._setFallback(fallback) + else: + self._fallback_function = _lookupDefaultElementClass + + cdef void _setFallback(self, ElementClassLookup lookup): + """Sets the fallback scheme for this lookup method. + """ + self.fallback = lookup + self._fallback_function = lookup._lookup_function + if self._fallback_function is NULL: + self._fallback_function = _lookupDefaultElementClass + + def set_fallback(self, ElementClassLookup lookup not None): + """set_fallback(self, lookup) + + Sets the fallback scheme for this lookup method. + """ + self._setFallback(lookup) + +cdef inline object _callLookupFallback(FallbackElementClassLookup lookup, + _Document doc, xmlNode* c_node): + return lookup._fallback_function(lookup.fallback, doc, c_node) + + +################################################################################ +# default lookup scheme + +cdef class ElementDefaultClassLookup(ElementClassLookup): + """ElementDefaultClassLookup(self, element=None, comment=None, pi=None, entity=None) + Element class lookup scheme that always returns the default Element + class. + + The keyword arguments ``element``, ``comment``, ``pi`` and ``entity`` + accept the respective Element classes. + """ + cdef readonly object element_class + cdef readonly object comment_class + cdef readonly object pi_class + cdef readonly object entity_class + def __cinit__(self): + self._lookup_function = _lookupDefaultElementClass + + def __init__(self, element=None, comment=None, pi=None, entity=None): + if element is None: + self.element_class = _Element + elif issubclass(element, ElementBase): + self.element_class = element + else: + raise TypeError, "element class must be subclass of ElementBase" + + if comment is None: + self.comment_class = _Comment + elif issubclass(comment, CommentBase): + self.comment_class = comment + else: + raise TypeError, "comment class must be subclass of CommentBase" + + if entity is None: + self.entity_class = _Entity + elif issubclass(entity, EntityBase): + self.entity_class = entity + else: + raise TypeError, "Entity class must be subclass of EntityBase" + + if pi is None: + self.pi_class = None # special case, see below + elif issubclass(pi, PIBase): + self.pi_class = pi + else: + raise TypeError, "PI class must be subclass of PIBase" + +cdef object _lookupDefaultElementClass(state, _Document _doc, xmlNode* c_node): + "Trivial class lookup function that always returns the default class." + if c_node.type == tree.XML_ELEMENT_NODE: + if state is not None: + return (state).element_class + else: + return _Element + elif c_node.type == tree.XML_COMMENT_NODE: + if state is not None: + return (state).comment_class + else: + return _Comment + elif c_node.type == tree.XML_ENTITY_REF_NODE: + if state is not None: + return (state).entity_class + else: + return _Entity + elif c_node.type == tree.XML_PI_NODE: + if state is None or (state).pi_class is None: + # special case XSLT-PI + if c_node.name is not NULL and c_node.content is not NULL: + if tree.xmlStrcmp(c_node.name, "xml-stylesheet") == 0: + if tree.xmlStrstr(c_node.content, "text/xsl") is not NULL or \ + tree.xmlStrstr(c_node.content, "text/xml") is not NULL: + return _XSLTProcessingInstruction + return _ProcessingInstruction + else: + return (state).pi_class + else: + assert False, f"Unknown node type: {c_node.type}" + + +################################################################################ +# attribute based lookup scheme + +cdef class AttributeBasedElementClassLookup(FallbackElementClassLookup): + """AttributeBasedElementClassLookup(self, attribute_name, class_mapping, fallback=None) + Checks an attribute of an Element and looks up the value in a + class dictionary. + + Arguments: + - attribute name - '{ns}name' style string + - class mapping - Python dict mapping attribute values to Element classes + - fallback - optional fallback lookup mechanism + + A None key in the class mapping will be checked if the attribute is + missing. + """ + cdef object _class_mapping + cdef tuple _pytag + cdef const_xmlChar* _c_ns + cdef const_xmlChar* _c_name + def __cinit__(self): + self._lookup_function = _attribute_class_lookup + + def __init__(self, attribute_name, class_mapping, + ElementClassLookup fallback=None): + self._pytag = _getNsTag(attribute_name) + ns, name = self._pytag + if ns is None: + self._c_ns = NULL + else: + self._c_ns = _xcstr(ns) + self._c_name = _xcstr(name) + self._class_mapping = dict(class_mapping) + + FallbackElementClassLookup.__init__(self, fallback) + +cdef object _attribute_class_lookup(state, _Document doc, xmlNode* c_node): + cdef AttributeBasedElementClassLookup lookup + cdef python.PyObject* dict_result + + lookup = state + if c_node.type == tree.XML_ELEMENT_NODE: + value = _attributeValueFromNsName( + c_node, lookup._c_ns, lookup._c_name) + dict_result = python.PyDict_GetItem(lookup._class_mapping, value) + if dict_result is not NULL: + cls = dict_result + _validateNodeClass(c_node, cls) + return cls + return _callLookupFallback(lookup, doc, c_node) + + +################################################################################ +# per-parser lookup scheme + +cdef class ParserBasedElementClassLookup(FallbackElementClassLookup): + """ParserBasedElementClassLookup(self, fallback=None) + Element class lookup based on the XML parser. + """ + def __cinit__(self): + self._lookup_function = _parser_class_lookup + +cdef object _parser_class_lookup(state, _Document doc, xmlNode* c_node): + if doc._parser._class_lookup is not None: + return doc._parser._class_lookup._lookup_function( + doc._parser._class_lookup, doc, c_node) + return _callLookupFallback(state, doc, c_node) + + +################################################################################ +# custom class lookup based on node type, namespace, name + +cdef class CustomElementClassLookup(FallbackElementClassLookup): + """CustomElementClassLookup(self, fallback=None) + Element class lookup based on a subclass method. + + You can inherit from this class and override the method:: + + lookup(self, type, doc, namespace, name) + + to lookup the element class for a node. Arguments of the method: + * type: one of 'element', 'comment', 'PI', 'entity' + * doc: document that the node is in + * namespace: namespace URI of the node (or None for comments/PIs/entities) + * name: name of the element/entity, None for comments, target for PIs + + If you return None from this method, the fallback will be called. + """ + def __cinit__(self): + self._lookup_function = _custom_class_lookup + + def lookup(self, type, doc, namespace, name): + "lookup(self, type, doc, namespace, name)" + return None + +cdef object _custom_class_lookup(state, _Document doc, xmlNode* c_node): + cdef CustomElementClassLookup lookup + + lookup = state + + if c_node.type == tree.XML_ELEMENT_NODE: + element_type = "element" + elif c_node.type == tree.XML_COMMENT_NODE: + element_type = "comment" + elif c_node.type == tree.XML_PI_NODE: + element_type = "PI" + elif c_node.type == tree.XML_ENTITY_REF_NODE: + element_type = "entity" + else: + element_type = "element" + if c_node.name is NULL: + name = None + else: + name = funicode(c_node.name) + c_str = tree._getNs(c_node) + ns = funicode(c_str) if c_str is not NULL else None + + cls = lookup.lookup(element_type, doc, ns, name) + if cls is not None: + _validateNodeClass(c_node, cls) + return cls + return _callLookupFallback(lookup, doc, c_node) + + +################################################################################ +# read-only tree based class lookup + +cdef class PythonElementClassLookup(FallbackElementClassLookup): + """PythonElementClassLookup(self, fallback=None) + Element class lookup based on a subclass method. + + This class lookup scheme allows access to the entire XML tree in + read-only mode. To use it, re-implement the ``lookup(self, doc, + root)`` method in a subclass:: + + from lxml import etree, pyclasslookup + + class MyElementClass(etree.ElementBase): + honkey = True + + class MyLookup(pyclasslookup.PythonElementClassLookup): + def lookup(self, doc, root): + if root.tag == "sometag": + return MyElementClass + else: + for child in root: + if child.tag == "someothertag": + return MyElementClass + # delegate to default + return None + + If you return None from this method, the fallback will be called. + + The first argument is the opaque document instance that contains + the Element. The second argument is a lightweight Element proxy + implementation that is only valid during the lookup. Do not try + to keep a reference to it. Once the lookup is done, the proxy + will be invalid. + + Also, you cannot wrap such a read-only Element in an ElementTree, + and you must take care not to keep a reference to them outside of + the `lookup()` method. + + Note that the API of the Element objects is not complete. It is + purely read-only and does not support all features of the normal + `lxml.etree` API (such as XPath, extended slicing or some + iteration methods). + + See https://lxml.de/element_classes.html + """ + def __cinit__(self): + self._lookup_function = _python_class_lookup + + def lookup(self, doc, element): + """lookup(self, doc, element) + + Override this method to implement your own lookup scheme. + """ + return None + +cdef object _python_class_lookup(state, _Document doc, tree.xmlNode* c_node): + cdef PythonElementClassLookup lookup + cdef _ReadOnlyProxy proxy + lookup = state + + proxy = _newReadOnlyProxy(None, c_node) + cls = lookup.lookup(doc, proxy) + _freeReadOnlyProxies(proxy) + + if cls is not None: + _validateNodeClass(c_node, cls) + return cls + return _callLookupFallback(lookup, doc, c_node) + +################################################################################ +# Global setup + +cdef _element_class_lookup_function LOOKUP_ELEMENT_CLASS +cdef object ELEMENT_CLASS_LOOKUP_STATE + +cdef void _setElementClassLookupFunction( + _element_class_lookup_function function, object state): + global LOOKUP_ELEMENT_CLASS, ELEMENT_CLASS_LOOKUP_STATE + if function is NULL: + state = DEFAULT_ELEMENT_CLASS_LOOKUP + function = DEFAULT_ELEMENT_CLASS_LOOKUP._lookup_function + + ELEMENT_CLASS_LOOKUP_STATE = state + LOOKUP_ELEMENT_CLASS = function + +def set_element_class_lookup(ElementClassLookup lookup = None): + """set_element_class_lookup(lookup = None) + + Set the global element class lookup method. + + This defines the main entry point for looking up element implementations. + The standard implementation uses the :class:`ParserBasedElementClassLookup` + to delegate to different lookup schemes for each parser. + + .. warning:: + + This should only be changed by applications, not by library packages. + In most cases, parser specific lookups should be preferred, + which can be configured via + :meth:`~lxml.etree.XMLParser.set_element_class_lookup` + (and the same for HTML parsers). + + Globally replacing the element class lookup by something other than a + :class:`ParserBasedElementClassLookup` will prevent parser specific lookup + schemes from working. Several tools rely on parser specific lookups, + including :mod:`lxml.html` and :mod:`lxml.objectify`. + """ + if lookup is None or lookup._lookup_function is NULL: + _setElementClassLookupFunction(NULL, None) + else: + _setElementClassLookupFunction(lookup._lookup_function, lookup) + +# default setup: parser delegation +cdef ParserBasedElementClassLookup DEFAULT_ELEMENT_CLASS_LOOKUP +DEFAULT_ELEMENT_CLASS_LOOKUP = ParserBasedElementClassLookup() + +set_element_class_lookup(DEFAULT_ELEMENT_CLASS_LOOKUP) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/cleanup.pxi b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/cleanup.pxi new file mode 100644 index 0000000000000000000000000000000000000000..8e266b33f0f3aef34f3448276abfb2cb8b1e4772 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/cleanup.pxi @@ -0,0 +1,215 @@ +# functions for tree cleanup and removing elements from subtrees + +def cleanup_namespaces(tree_or_element, top_nsmap=None, keep_ns_prefixes=None): + """cleanup_namespaces(tree_or_element, top_nsmap=None, keep_ns_prefixes=None) + + Remove all namespace declarations from a subtree that are not used + by any of the elements or attributes in that tree. + + If a 'top_nsmap' is provided, it must be a mapping from prefixes + to namespace URIs. These namespaces will be declared on the top + element of the subtree before running the cleanup, which allows + moving namespace declarations to the top of the tree. + + If a 'keep_ns_prefixes' is provided, it must be a list of prefixes. + These prefixes will not be removed as part of the cleanup. + """ + element = _rootNodeOrRaise(tree_or_element) + c_element = element._c_node + + if top_nsmap: + doc = element._doc + # declare namespaces from nsmap, then apply them to the subtree + _setNodeNamespaces(c_element, doc, None, top_nsmap) + moveNodeToDocument(doc, c_element.doc, c_element) + + keep_ns_prefixes = ( + set([_utf8(prefix) for prefix in keep_ns_prefixes]) + if keep_ns_prefixes else None) + + _removeUnusedNamespaceDeclarations(c_element, keep_ns_prefixes) + + +def strip_attributes(tree_or_element, *attribute_names): + """strip_attributes(tree_or_element, *attribute_names) + + Delete all attributes with the provided attribute names from an + Element (or ElementTree) and its descendants. + + Attribute names can contain wildcards as in `_Element.iter`. + + Example usage:: + + strip_attributes(root_element, + 'simpleattr', + '{http://some/ns}attrname', + '{http://other/ns}*') + """ + cdef _MultiTagMatcher matcher + element = _rootNodeOrRaise(tree_or_element) + if not attribute_names: + return + + matcher = _MultiTagMatcher.__new__(_MultiTagMatcher, attribute_names) + matcher.cacheTags(element._doc) + if matcher.rejectsAllAttributes(): + return + _strip_attributes(element._c_node, matcher) + + +cdef _strip_attributes(xmlNode* c_node, _MultiTagMatcher matcher): + cdef xmlAttr* c_attr + cdef xmlAttr* c_next_attr + tree.BEGIN_FOR_EACH_ELEMENT_FROM(c_node, c_node, 1) + if c_node.type == tree.XML_ELEMENT_NODE: + c_attr = c_node.properties + while c_attr is not NULL: + c_next_attr = c_attr.next + if matcher.matchesAttribute(c_attr): + tree.xmlRemoveProp(c_attr) + c_attr = c_next_attr + tree.END_FOR_EACH_ELEMENT_FROM(c_node) + + +def strip_elements(tree_or_element, *tag_names, bint with_tail=True): + """strip_elements(tree_or_element, *tag_names, with_tail=True) + + Delete all elements with the provided tag names from a tree or + subtree. This will remove the elements and their entire subtree, + including all their attributes, text content and descendants. It + will also remove the tail text of the element unless you + explicitly set the ``with_tail`` keyword argument option to False. + + Tag names can contain wildcards as in `_Element.iter`. + + Note that this will not delete the element (or ElementTree root + element) that you passed even if it matches. It will only treat + its descendants. If you want to include the root element, check + its tag name directly before even calling this function. + + Example usage:: + + strip_elements(some_element, + 'simpletagname', # non-namespaced tag + '{http://some/ns}tagname', # namespaced tag + '{http://some/other/ns}*' # any tag from a namespace + lxml.etree.Comment # comments + ) + """ + cdef _MultiTagMatcher matcher + doc = _documentOrRaise(tree_or_element) + element = _rootNodeOrRaise(tree_or_element) + if not tag_names: + return + + matcher = _MultiTagMatcher.__new__(_MultiTagMatcher, tag_names) + matcher.cacheTags(doc) + if matcher.rejectsAll(): + return + + if isinstance(tree_or_element, _ElementTree): + # include PIs and comments next to the root node + if matcher.matchesType(tree.XML_COMMENT_NODE): + _removeSiblings(element._c_node, tree.XML_COMMENT_NODE, with_tail) + if matcher.matchesType(tree.XML_PI_NODE): + _removeSiblings(element._c_node, tree.XML_PI_NODE, with_tail) + _strip_elements(doc, element._c_node, matcher, with_tail) + +cdef _strip_elements(_Document doc, xmlNode* c_node, _MultiTagMatcher matcher, + bint with_tail): + cdef xmlNode* c_child + cdef xmlNode* c_next + + tree.BEGIN_FOR_EACH_ELEMENT_FROM(c_node, c_node, 1) + if c_node.type == tree.XML_ELEMENT_NODE: + # we run through the children here to prevent any problems + # with the tree iteration which would occur if we unlinked the + # c_node itself + c_child = _findChildForwards(c_node, 0) + while c_child is not NULL: + c_next = _nextElement(c_child) + if matcher.matches(c_child): + if c_child.type == tree.XML_ELEMENT_NODE: + if not with_tail: + tree.xmlUnlinkNode(c_child) + _removeNode(doc, c_child) + else: + if with_tail: + _removeText(c_child.next) + tree.xmlUnlinkNode(c_child) + attemptDeallocation(c_child) + c_child = c_next + tree.END_FOR_EACH_ELEMENT_FROM(c_node) + + +def strip_tags(tree_or_element, *tag_names): + """strip_tags(tree_or_element, *tag_names) + + Delete all elements with the provided tag names from a tree or + subtree. This will remove the elements and their attributes, but + *not* their text/tail content or descendants. Instead, it will + merge the text content and children of the element into its + parent. + + Tag names can contain wildcards as in `_Element.iter`. + + Note that this will not delete the element (or ElementTree root + element) that you passed even if it matches. It will only treat + its descendants. + + Example usage:: + + strip_tags(some_element, + 'simpletagname', # non-namespaced tag + '{http://some/ns}tagname', # namespaced tag + '{http://some/other/ns}*' # any tag from a namespace + Comment # comments (including their text!) + ) + """ + cdef _MultiTagMatcher matcher + doc = _documentOrRaise(tree_or_element) + element = _rootNodeOrRaise(tree_or_element) + if not tag_names: + return + + matcher = _MultiTagMatcher.__new__(_MultiTagMatcher, tag_names) + matcher.cacheTags(doc) + if matcher.rejectsAll(): + return + + if isinstance(tree_or_element, _ElementTree): + # include PIs and comments next to the root node + if matcher.matchesType(tree.XML_COMMENT_NODE): + _removeSiblings(element._c_node, tree.XML_COMMENT_NODE, 0) + if matcher.matchesType(tree.XML_PI_NODE): + _removeSiblings(element._c_node, tree.XML_PI_NODE, 0) + _strip_tags(doc, element._c_node, matcher) + +cdef _strip_tags(_Document doc, xmlNode* c_node, _MultiTagMatcher matcher): + cdef xmlNode* c_child + cdef xmlNode* c_next + + tree.BEGIN_FOR_EACH_ELEMENT_FROM(c_node, c_node, 1) + if c_node.type == tree.XML_ELEMENT_NODE: + # we run through the children here to prevent any problems + # with the tree iteration which would occur if we unlinked the + # c_node itself + c_child = _findChildForwards(c_node, 0) + while c_child is not NULL: + if not matcher.matches(c_child): + c_child = _nextElement(c_child) + continue + if c_child.type == tree.XML_ELEMENT_NODE: + c_next = _findChildForwards(c_child, 0) or _nextElement(c_child) + _replaceNodeByChildren(doc, c_child) + if not attemptDeallocation(c_child): + if c_child.nsDef is not NULL: + # make namespaces absolute + moveNodeToDocument(doc, doc._c_doc, c_child) + c_child = c_next + else: + c_next = _nextElement(c_child) + tree.xmlUnlinkNode(c_child) + attemptDeallocation(c_child) + c_child = c_next + tree.END_FOR_EACH_ELEMENT_FROM(c_node) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/cssselect.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/cssselect.py new file mode 100644 index 0000000000000000000000000000000000000000..54cd75ac9bfecdec7ea81e91b0840c6edd401515 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/cssselect.py @@ -0,0 +1,101 @@ +"""CSS Selectors based on XPath. + +This module supports selecting XML/HTML tags based on CSS selectors. +See the `CSSSelector` class for details. + +This is a thin wrapper around cssselect 0.7 or later. +""" + + +from . import etree +try: + import cssselect as external_cssselect +except ImportError: + raise ImportError( + 'cssselect does not seem to be installed. ' + 'See https://pypi.org/project/cssselect/') + + +SelectorSyntaxError = external_cssselect.SelectorSyntaxError +ExpressionError = external_cssselect.ExpressionError +SelectorError = external_cssselect.SelectorError + + +__all__ = ['SelectorSyntaxError', 'ExpressionError', 'SelectorError', + 'CSSSelector'] + + +class LxmlTranslator(external_cssselect.GenericTranslator): + """ + A custom CSS selector to XPath translator with lxml-specific extensions. + """ + def xpath_contains_function(self, xpath, function): + # Defined there, removed in later drafts: + # http://www.w3.org/TR/2001/CR-css3-selectors-20011113/#content-selectors + if function.argument_types() not in (['STRING'], ['IDENT']): + raise ExpressionError( + "Expected a single string or ident for :contains(), got %r" + % function.arguments) + value = function.arguments[0].value + return xpath.add_condition( + 'contains(__lxml_internal_css:lower-case(string(.)), %s)' + % self.xpath_literal(value.lower())) + + +class LxmlHTMLTranslator(LxmlTranslator, external_cssselect.HTMLTranslator): + """ + lxml extensions + HTML support. + """ + + +def _make_lower_case(context, s): + return s.lower() + +ns = etree.FunctionNamespace('http://codespeak.net/lxml/css/') +ns.prefix = '__lxml_internal_css' +ns['lower-case'] = _make_lower_case + + +class CSSSelector(etree.XPath): + """A CSS selector. + + Usage:: + + >>> from lxml import etree, cssselect + >>> select = cssselect.CSSSelector("a tag > child") + + >>> root = etree.XML("TEXT") + >>> [ el.tag for el in select(root) ] + ['child'] + + To use CSS namespaces, you need to pass a prefix-to-namespace + mapping as ``namespaces`` keyword argument:: + + >>> rdfns = 'http://www.w3.org/1999/02/22-rdf-syntax-ns#' + >>> select_ns = cssselect.CSSSelector('root > rdf|Description', + ... namespaces={'rdf': rdfns}) + + >>> rdf = etree.XML(( + ... '' + ... 'blah' + ... '') % rdfns) + >>> [(el.tag, el.text) for el in select_ns(rdf)] + [('{http://www.w3.org/1999/02/22-rdf-syntax-ns#}Description', 'blah')] + + """ + def __init__(self, css, namespaces=None, translator='xml'): + if translator == 'xml': + translator = LxmlTranslator() + elif translator == 'html': + translator = LxmlHTMLTranslator() + elif translator == 'xhtml': + translator = LxmlHTMLTranslator(xhtml=True) + path = translator.css_to_xpath(css) + super().__init__(path, namespaces=namespaces) + self.css = css + + def __repr__(self): + return '<%s %x for %r>' % ( + self.__class__.__name__, + abs(id(self)), + self.css) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/doctestcompare.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/doctestcompare.py new file mode 100644 index 0000000000000000000000000000000000000000..8099771de906a37ed007c779f152fe96f182060d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/doctestcompare.py @@ -0,0 +1,488 @@ +""" +lxml-based doctest output comparison. + +Note: normally, you should just import the `lxml.usedoctest` and +`lxml.html.usedoctest` modules from within a doctest, instead of this +one:: + + >>> import lxml.usedoctest # for XML output + + >>> import lxml.html.usedoctest # for HTML output + +To use this module directly, you must call ``lxmldoctest.install()``, +which will cause doctest to use this in all subsequent calls. + +This changes the way output is checked and comparisons are made for +XML or HTML-like content. + +XML or HTML content is noticed because the example starts with ``<`` +(it's HTML if it starts with ```` or include an ``any`` +attribute in the tag. An ``any`` tag matches any tag, while the +attribute matches any and all attributes. + +When a match fails, the reformatted example and gotten text is +displayed (indented), and a rough diff-like output is given. Anything +marked with ``+`` is in the output but wasn't supposed to be, and +similarly ``-`` means its in the example but wasn't in the output. + +You can disable parsing on one line with ``# doctest:+NOPARSE_MARKUP`` +""" + +from lxml import etree +import sys +import re +import doctest +try: + from html import escape as html_escape +except ImportError: + from cgi import escape as html_escape + +__all__ = ['PARSE_HTML', 'PARSE_XML', 'NOPARSE_MARKUP', 'LXMLOutputChecker', + 'LHTMLOutputChecker', 'install', 'temp_install'] + +PARSE_HTML = doctest.register_optionflag('PARSE_HTML') +PARSE_XML = doctest.register_optionflag('PARSE_XML') +NOPARSE_MARKUP = doctest.register_optionflag('NOPARSE_MARKUP') + +OutputChecker = doctest.OutputChecker + +def strip(v): + if v is None: + return None + else: + return v.strip() + +def norm_whitespace(v): + return _norm_whitespace_re.sub(' ', v) + +_html_parser = etree.HTMLParser(recover=False, remove_blank_text=True) + +def html_fromstring(html): + return etree.fromstring(html, _html_parser) + +# We use this to distinguish repr()s from elements: +_repr_re = re.compile(r'^<[^>]+ (at|object) ') +_norm_whitespace_re = re.compile(r'[ \t\n][ \t\n]+') + +class LXMLOutputChecker(OutputChecker): + + empty_tags = ( + 'param', 'img', 'area', 'br', 'basefont', 'input', + 'base', 'meta', 'link', 'col') + + def get_default_parser(self): + return etree.XML + + def check_output(self, want, got, optionflags): + alt_self = getattr(self, '_temp_override_self', None) + if alt_self is not None: + super_method = self._temp_call_super_check_output + self = alt_self + else: + super_method = OutputChecker.check_output + parser = self.get_parser(want, got, optionflags) + if not parser: + return super_method( + self, want, got, optionflags) + try: + want_doc = parser(want) + except etree.XMLSyntaxError: + return False + try: + got_doc = parser(got) + except etree.XMLSyntaxError: + return False + return self.compare_docs(want_doc, got_doc) + + def get_parser(self, want, got, optionflags): + parser = None + if NOPARSE_MARKUP & optionflags: + return None + if PARSE_HTML & optionflags: + parser = html_fromstring + elif PARSE_XML & optionflags: + parser = etree.XML + elif (want.strip().lower().startswith('' % el.tag + return '<%s %s>' % (el.tag, ' '.join(attrs)) + + def format_end_tag(self, el): + if isinstance(el, etree.CommentBase): + # FIXME: probably PIs should be handled specially too? + return '-->' + return '' % el.tag + + def collect_diff(self, want, got, html, indent): + parts = [] + if not len(want) and not len(got): + parts.append(' '*indent) + parts.append(self.collect_diff_tag(want, got)) + if not self.html_empty_tag(got, html): + parts.append(self.collect_diff_text(want.text, got.text)) + parts.append(self.collect_diff_end_tag(want, got)) + parts.append(self.collect_diff_text(want.tail, got.tail)) + parts.append('\n') + return ''.join(parts) + parts.append(' '*indent) + parts.append(self.collect_diff_tag(want, got)) + parts.append('\n') + if strip(want.text) or strip(got.text): + parts.append(' '*indent) + parts.append(self.collect_diff_text(want.text, got.text)) + parts.append('\n') + want_children = list(want) + got_children = list(got) + while want_children or got_children: + if not want_children: + parts.append(self.format_doc(got_children.pop(0), html, indent+2, '+')) + continue + if not got_children: + parts.append(self.format_doc(want_children.pop(0), html, indent+2, '-')) + continue + parts.append(self.collect_diff( + want_children.pop(0), got_children.pop(0), html, indent+2)) + parts.append(' '*indent) + parts.append(self.collect_diff_end_tag(want, got)) + parts.append('\n') + if strip(want.tail) or strip(got.tail): + parts.append(' '*indent) + parts.append(self.collect_diff_text(want.tail, got.tail)) + parts.append('\n') + return ''.join(parts) + + def collect_diff_tag(self, want, got): + if not self.tag_compare(want.tag, got.tag): + tag = '%s (got: %s)' % (want.tag, got.tag) + else: + tag = got.tag + attrs = [] + any = want.tag == 'any' or 'any' in want.attrib + for name, value in sorted(got.attrib.items()): + if name not in want.attrib and not any: + attrs.append('+%s="%s"' % (name, self.format_text(value, False))) + else: + if name in want.attrib: + text = self.collect_diff_text(want.attrib[name], value, False) + else: + text = self.format_text(value, False) + attrs.append('%s="%s"' % (name, text)) + if not any: + for name, value in sorted(want.attrib.items()): + if name in got.attrib: + continue + attrs.append('-%s="%s"' % (name, self.format_text(value, False))) + if attrs: + tag = '<%s %s>' % (tag, ' '.join(attrs)) + else: + tag = '<%s>' % tag + return tag + + def collect_diff_end_tag(self, want, got): + if want.tag != got.tag: + tag = '%s (got: %s)' % (want.tag, got.tag) + else: + tag = got.tag + return '' % tag + + def collect_diff_text(self, want, got, strip=True): + if self.text_compare(want, got, strip): + if not got: + return '' + return self.format_text(got, strip) + text = '%s (got: %s)' % (want, got) + return self.format_text(text, strip) + +class LHTMLOutputChecker(LXMLOutputChecker): + def get_default_parser(self): + return html_fromstring + +def install(html=False): + """ + Install doctestcompare for all future doctests. + + If html is true, then by default the HTML parser will be used; + otherwise the XML parser is used. + """ + if html: + doctest.OutputChecker = LHTMLOutputChecker + else: + doctest.OutputChecker = LXMLOutputChecker + +def temp_install(html=False, del_module=None): + """ + Use this *inside* a doctest to enable this checker for this + doctest only. + + If html is true, then by default the HTML parser will be used; + otherwise the XML parser is used. + """ + if html: + Checker = LHTMLOutputChecker + else: + Checker = LXMLOutputChecker + frame = _find_doctest_frame() + dt_self = frame.f_locals['self'] + checker = Checker() + old_checker = dt_self._checker + dt_self._checker = checker + # The unfortunate thing is that there is a local variable 'check' + # in the function that runs the doctests, that is a bound method + # into the output checker. We have to update that. We can't + # modify the frame, so we have to modify the object in place. The + # only way to do this is to actually change the func_code + # attribute of the method. We change it, and then wait for + # __record_outcome to be run, which signals the end of the __run + # method, at which point we restore the previous check_output + # implementation. + check_func = frame.f_locals['check'].__func__ + checker_check_func = checker.check_output.__func__ + # Because we can't patch up func_globals, this is the only global + # in check_output that we care about: + doctest.etree = etree + _RestoreChecker(dt_self, old_checker, checker, + check_func, checker_check_func, + del_module) + +class _RestoreChecker: + def __init__(self, dt_self, old_checker, new_checker, check_func, clone_func, + del_module): + self.dt_self = dt_self + self.checker = old_checker + self.checker._temp_call_super_check_output = self.call_super + self.checker._temp_override_self = new_checker + self.check_func = check_func + self.clone_func = clone_func + self.del_module = del_module + self.install_clone() + self.install_dt_self() + def install_clone(self): + self.func_code = self.check_func.__code__ + self.func_globals = self.check_func.__globals__ + self.check_func.__code__ = self.clone_func.__code__ + def uninstall_clone(self): + self.check_func.__code__ = self.func_code + def install_dt_self(self): + self.prev_func = self.dt_self._DocTestRunner__record_outcome + self.dt_self._DocTestRunner__record_outcome = self + def uninstall_dt_self(self): + self.dt_self._DocTestRunner__record_outcome = self.prev_func + def uninstall_module(self): + if self.del_module: + import sys + del sys.modules[self.del_module] + if '.' in self.del_module: + package, module = self.del_module.rsplit('.', 1) + package_mod = sys.modules[package] + delattr(package_mod, module) + def __call__(self, *args, **kw): + self.uninstall_clone() + self.uninstall_dt_self() + del self.checker._temp_override_self + del self.checker._temp_call_super_check_output + result = self.prev_func(*args, **kw) + self.uninstall_module() + return result + def call_super(self, *args, **kw): + self.uninstall_clone() + try: + return self.check_func(*args, **kw) + finally: + self.install_clone() + +def _find_doctest_frame(): + import sys + frame = sys._getframe(1) + while frame: + l = frame.f_locals + if 'BOOM' in l: + # Sign of doctest + return frame + frame = frame.f_back + raise LookupError( + "Could not find doctest (only use this function *inside* a doctest)") + +__test__ = { + 'basic': ''' + >>> temp_install() + >>> print """stuff""" + ... + >>> print """""" + + + + >>> print """blahblahblah""" # doctest: +NOPARSE_MARKUP, +ELLIPSIS + ...foo /> + '''} + +if __name__ == '__main__': + import doctest + doctest.testmod() + + diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/etree.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/etree.h new file mode 100644 index 0000000000000000000000000000000000000000..17b99a7be5c4159429d575c9e98f621f57c8310c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/etree.h @@ -0,0 +1,244 @@ +/* Generated by Cython 3.1.4 */ + +#ifndef __PYX_HAVE__lxml__etree +#define __PYX_HAVE__lxml__etree + +#include "Python.h" +struct LxmlDocument; +struct LxmlElement; +struct LxmlElementTree; +struct LxmlElementTagMatcher; +struct LxmlElementIterator; +struct LxmlElementBase; +struct LxmlElementClassLookup; +struct LxmlFallbackElementClassLookup; + +/* "lxml/etree.pyx":451 + * + * # type of a function that steps from node to node + * ctypedef public xmlNode* (*_node_to_node_function)(xmlNode*) # <<<<<<<<<<<<<< + * + * +*/ +typedef xmlNode *(*_node_to_node_function)(xmlNode *); + +/* "lxml/etree.pyx":465 + * # Public Python API + * + * @cython.final # <<<<<<<<<<<<<< + * @cython.freelist(8) + * cdef public class _Document [ type LxmlDocumentType, object LxmlDocument ]: +*/ +struct LxmlDocument { + PyObject_HEAD + struct __pyx_vtabstruct_4lxml_5etree__Document *__pyx_vtab; + int _ns_counter; + PyObject *_prefix_tail; + xmlDoc *_c_doc; + struct __pyx_obj_4lxml_5etree__BaseParser *_parser; +}; + +/* "lxml/etree.pyx":817 + * + * + * @cython.no_gc_clear # <<<<<<<<<<<<<< + * cdef public class _Element [ type LxmlElementType, object LxmlElement ]: + * """Element class. +*/ +struct LxmlElement { + PyObject_HEAD + struct LxmlDocument *_doc; + xmlNode *_c_node; + PyObject *_tag; +}; + +/* "lxml/etree.pyx":1991 + * + * + * cdef public class _ElementTree [ type LxmlElementTreeType, # <<<<<<<<<<<<<< + * object LxmlElementTree ]: + * cdef _Document _doc +*/ +struct LxmlElementTree { + PyObject_HEAD + struct __pyx_vtabstruct_4lxml_5etree__ElementTree *__pyx_vtab; + struct LxmlDocument *_doc; + struct LxmlElement *_context_node; +}; + +/* "lxml/etree.pyx":2765 + * + * + * cdef public class _ElementTagMatcher [ object LxmlElementTagMatcher, # <<<<<<<<<<<<<< + * type LxmlElementTagMatcherType ]: + * """ +*/ +struct LxmlElementTagMatcher { + PyObject_HEAD + struct __pyx_vtabstruct_4lxml_5etree__ElementTagMatcher *__pyx_vtab; + PyObject *_pystrings; + int _node_type; + char *_href; + char *_name; +}; + +/* "lxml/etree.pyx":2796 + * self._name = NULL + * + * cdef public class _ElementIterator(_ElementTagMatcher) [ # <<<<<<<<<<<<<< + * object LxmlElementIterator, type LxmlElementIteratorType ]: + * """ +*/ +struct LxmlElementIterator { + struct LxmlElementTagMatcher __pyx_base; + struct LxmlElement *_node; + _node_to_node_function _next_element; +}; + +/* "src/lxml/classlookup.pxi":6 + * # Custom Element classes + * + * cdef public class ElementBase(_Element) [ type LxmlElementBaseType, # <<<<<<<<<<<<<< + * object LxmlElementBase ]: + * """ElementBase(*children, attrib=None, nsmap=None, **_extra) +*/ +struct LxmlElementBase { + struct LxmlElement __pyx_base; +}; + +/* "src/lxml/classlookup.pxi":210 + * # Element class lookup + * + * ctypedef public object (*_element_class_lookup_function)(object, _Document, xmlNode*) # <<<<<<<<<<<<<< + * + * # class to store element class lookup functions +*/ +typedef PyObject *(*_element_class_lookup_function)(PyObject *, struct LxmlDocument *, xmlNode *); + +/* "src/lxml/classlookup.pxi":213 + * + * # class to store element class lookup functions + * cdef public class ElementClassLookup [ type LxmlElementClassLookupType, # <<<<<<<<<<<<<< + * object LxmlElementClassLookup ]: + * """ElementClassLookup(self) +*/ +struct LxmlElementClassLookup { + PyObject_HEAD + _element_class_lookup_function _lookup_function; +}; + +/* "src/lxml/classlookup.pxi":221 + * + * + * cdef public class FallbackElementClassLookup(ElementClassLookup) \ # <<<<<<<<<<<<<< + * [ type LxmlFallbackElementClassLookupType, + * object LxmlFallbackElementClassLookup ]: +*/ +struct LxmlFallbackElementClassLookup { + struct LxmlElementClassLookup __pyx_base; + struct __pyx_vtabstruct_4lxml_5etree_FallbackElementClassLookup *__pyx_vtab; + struct LxmlElementClassLookup *fallback; + _element_class_lookup_function _fallback_function; +}; + +#ifndef __PYX_HAVE_API__lxml__etree + +#ifdef CYTHON_EXTERN_C + #undef __PYX_EXTERN_C + #define __PYX_EXTERN_C CYTHON_EXTERN_C +#elif defined(__PYX_EXTERN_C) + #ifdef _MSC_VER + #pragma message ("Please do not define the '__PYX_EXTERN_C' macro externally. Use 'CYTHON_EXTERN_C' instead.") + #else + #warning Please do not define the '__PYX_EXTERN_C' macro externally. Use 'CYTHON_EXTERN_C' instead. + #endif +#else + #ifdef __cplusplus + #define __PYX_EXTERN_C extern "C" + #else + #define __PYX_EXTERN_C extern + #endif +#endif + +#ifndef DL_IMPORT + #define DL_IMPORT(_T) _T +#endif + +__PYX_EXTERN_C DL_IMPORT(PyTypeObject) LxmlDocumentType; +__PYX_EXTERN_C DL_IMPORT(PyTypeObject) LxmlElementType; +__PYX_EXTERN_C DL_IMPORT(PyTypeObject) LxmlElementTreeType; +__PYX_EXTERN_C DL_IMPORT(PyTypeObject) LxmlElementTagMatcherType; +__PYX_EXTERN_C DL_IMPORT(PyTypeObject) LxmlElementIteratorType; +__PYX_EXTERN_C DL_IMPORT(PyTypeObject) LxmlElementBaseType; +__PYX_EXTERN_C DL_IMPORT(PyTypeObject) LxmlElementClassLookupType; +__PYX_EXTERN_C DL_IMPORT(PyTypeObject) LxmlFallbackElementClassLookupType; + +__PYX_EXTERN_C struct LxmlElement *deepcopyNodeToDocument(struct LxmlDocument *, xmlNode *); +__PYX_EXTERN_C struct LxmlElementTree *elementTreeFactory(struct LxmlElement *); +__PYX_EXTERN_C struct LxmlElementTree *newElementTree(struct LxmlElement *, PyObject *); +__PYX_EXTERN_C struct LxmlElementTree *adoptExternalDocument(xmlDoc *, PyObject *, int); +__PYX_EXTERN_C struct LxmlElement *elementFactory(struct LxmlDocument *, xmlNode *); +__PYX_EXTERN_C struct LxmlElement *makeElement(PyObject *, struct LxmlDocument *, PyObject *, PyObject *, PyObject *, PyObject *, PyObject *); +__PYX_EXTERN_C struct LxmlElement *makeSubElement(struct LxmlElement *, PyObject *, PyObject *, PyObject *, PyObject *, PyObject *); +__PYX_EXTERN_C void setElementClassLookupFunction(_element_class_lookup_function, PyObject *); +__PYX_EXTERN_C PyObject *lookupDefaultElementClass(PyObject *, PyObject *, xmlNode *); +__PYX_EXTERN_C PyObject *lookupNamespaceElementClass(PyObject *, PyObject *, xmlNode *); +__PYX_EXTERN_C PyObject *callLookupFallback(struct LxmlFallbackElementClassLookup *, struct LxmlDocument *, xmlNode *); +__PYX_EXTERN_C int tagMatches(xmlNode *, const xmlChar *, const xmlChar *); +__PYX_EXTERN_C struct LxmlDocument *documentOrRaise(PyObject *); +__PYX_EXTERN_C struct LxmlElement *rootNodeOrRaise(PyObject *); +__PYX_EXTERN_C int hasText(xmlNode *); +__PYX_EXTERN_C int hasTail(xmlNode *); +__PYX_EXTERN_C PyObject *textOf(xmlNode *); +__PYX_EXTERN_C PyObject *tailOf(xmlNode *); +__PYX_EXTERN_C int setNodeText(xmlNode *, PyObject *); +__PYX_EXTERN_C int setTailText(xmlNode *, PyObject *); +__PYX_EXTERN_C PyObject *attributeValue(xmlNode *, xmlAttr *); +__PYX_EXTERN_C PyObject *attributeValueFromNsName(xmlNode *, const xmlChar *, const xmlChar *); +__PYX_EXTERN_C PyObject *getAttributeValue(struct LxmlElement *, PyObject *, PyObject *); +__PYX_EXTERN_C PyObject *iterattributes(struct LxmlElement *, int); +__PYX_EXTERN_C PyObject *collectAttributes(xmlNode *, int); +__PYX_EXTERN_C int setAttributeValue(struct LxmlElement *, PyObject *, PyObject *); +__PYX_EXTERN_C int delAttribute(struct LxmlElement *, PyObject *); +__PYX_EXTERN_C int delAttributeFromNsName(xmlNode *, const xmlChar *, const xmlChar *); +__PYX_EXTERN_C int hasChild(xmlNode *); +__PYX_EXTERN_C xmlNode *findChild(xmlNode *, Py_ssize_t); +__PYX_EXTERN_C xmlNode *findChildForwards(xmlNode *, Py_ssize_t); +__PYX_EXTERN_C xmlNode *findChildBackwards(xmlNode *, Py_ssize_t); +__PYX_EXTERN_C xmlNode *nextElement(xmlNode *); +__PYX_EXTERN_C xmlNode *previousElement(xmlNode *); +__PYX_EXTERN_C void appendChild(struct LxmlElement *, struct LxmlElement *); +__PYX_EXTERN_C int appendChildToElement(struct LxmlElement *, struct LxmlElement *); +__PYX_EXTERN_C PyObject *pyunicode(const xmlChar *); +__PYX_EXTERN_C PyObject *utf8(PyObject *); +__PYX_EXTERN_C PyObject *getNsTag(PyObject *); +__PYX_EXTERN_C PyObject *getNsTagWithEmptyNs(PyObject *); +__PYX_EXTERN_C PyObject *namespacedName(xmlNode *); +__PYX_EXTERN_C PyObject *namespacedNameFromNsName(const xmlChar *, const xmlChar *); +__PYX_EXTERN_C void iteratorStoreNext(struct LxmlElementIterator *, struct LxmlElement *); +__PYX_EXTERN_C void initTagMatch(struct LxmlElementTagMatcher *, PyObject *); +__PYX_EXTERN_C xmlNs *findOrBuildNodeNsPrefix(struct LxmlDocument *, xmlNode *, const xmlChar *, const xmlChar *); + +#endif /* !__PYX_HAVE_API__lxml__etree */ + +/* WARNING: the interface of the module init function changed in CPython 3.5. */ +/* It now returns a PyModuleDef instance instead of a PyModule instance. */ + +/* WARNING: Use PyImport_AppendInittab("etree", PyInit_etree) instead of calling PyInit_etree directly from Python 3.5 */ +PyMODINIT_FUNC PyInit_etree(void); + +#if PY_VERSION_HEX >= 0x03050000 && (defined(__GNUC__) || defined(__clang__) || defined(_MSC_VER) || (defined(__cplusplus) && __cplusplus >= 201402L)) +#if defined(__cplusplus) && __cplusplus >= 201402L +[[deprecated("Use PyImport_AppendInittab(\"etree\", PyInit_etree) instead of calling PyInit_etree directly.")]] inline +#elif defined(__GNUC__) || defined(__clang__) +__attribute__ ((__deprecated__("Use PyImport_AppendInittab(\"etree\", PyInit_etree) instead of calling PyInit_etree directly."), __unused__)) __inline__ +#elif defined(_MSC_VER) +__declspec(deprecated("Use PyImport_AppendInittab(\"etree\", PyInit_etree) instead of calling PyInit_etree directly.")) __inline +#endif +static PyObject* __PYX_WARN_IF_PyInit_etree_INIT_CALLED(PyObject* res) { + return res; +} +#define PyInit_etree() __PYX_WARN_IF_PyInit_etree_INIT_CALLED(PyInit_etree()) +#endif + +#endif /* !__PYX_HAVE__lxml__etree */ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/etree.pyx b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/etree.pyx new file mode 100644 index 0000000000000000000000000000000000000000..562d95ed167945504fd50182824340a5931c4b10 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/etree.pyx @@ -0,0 +1,3853 @@ +# cython: binding=True +# cython: auto_pickle=False +# cython: language_level=3 + +""" +The ``lxml.etree`` module implements the extended ElementTree API for XML. +""" + +__docformat__ = "restructuredtext en" + +__all__ = [ + 'AttributeBasedElementClassLookup', 'C14NError', 'C14NWriterTarget', 'CDATA', + 'Comment', 'CommentBase', 'CustomElementClassLookup', 'DEBUG', + 'DTD', 'DTDError', 'DTDParseError', 'DTDValidateError', + 'DocumentInvalid', 'ETCompatXMLParser', 'ETXPath', 'Element', + 'ElementBase', 'ElementClassLookup', 'ElementDefaultClassLookup', + 'ElementNamespaceClassLookup', 'ElementTree', 'Entity', 'EntityBase', + 'Error', 'ErrorDomains', 'ErrorLevels', 'ErrorTypes', 'Extension', + 'FallbackElementClassLookup', 'FunctionNamespace', 'HTML', 'HTMLParser', + 'ICONV_COMPILED_VERSION', + 'LIBXML_COMPILED_VERSION', 'LIBXML_VERSION', + 'LIBXML_FEATURES', + 'LIBXSLT_COMPILED_VERSION', 'LIBXSLT_VERSION', + 'LXML_VERSION', + 'LxmlError', 'LxmlRegistryError', 'LxmlSyntaxError', + 'NamespaceRegistryError', 'PI', 'PIBase', 'ParseError', + 'ParserBasedElementClassLookup', 'ParserError', 'ProcessingInstruction', + 'PyErrorLog', 'PythonElementClassLookup', 'QName', 'RelaxNG', + 'RelaxNGError', 'RelaxNGErrorTypes', 'RelaxNGParseError', + 'RelaxNGValidateError', 'Resolver', 'Schematron', 'SchematronError', + 'SchematronParseError', 'SchematronValidateError', 'SerialisationError', + 'SubElement', 'TreeBuilder', 'XInclude', 'XIncludeError', 'XML', + 'XMLDTDID', 'XMLID', 'XMLParser', 'XMLSchema', 'XMLSchemaError', + 'XMLSchemaParseError', 'XMLSchemaValidateError', 'XMLSyntaxError', + 'XMLTreeBuilder', 'XPath', 'XPathDocumentEvaluator', 'XPathError', + 'XPathEvalError', 'XPathEvaluator', 'XPathFunctionError', 'XPathResultError', + 'XPathSyntaxError', 'XSLT', 'XSLTAccessControl', 'XSLTApplyError', + 'XSLTError', 'XSLTExtension', 'XSLTExtensionError', 'XSLTParseError', + 'XSLTSaveError', 'canonicalize', + 'cleanup_namespaces', 'clear_error_log', 'dump', + 'fromstring', 'fromstringlist', 'get_default_parser', 'iselement', + 'iterparse', 'iterwalk', 'parse', 'parseid', 'register_namespace', + 'set_default_parser', 'set_element_class_lookup', 'strip_attributes', + 'strip_elements', 'strip_tags', 'tostring', 'tostringlist', 'tounicode', + 'use_global_python_log' + ] + +cimport cython + +from lxml cimport python +from lxml.includes cimport tree, config +from lxml.includes.tree cimport xmlDoc, xmlNode, xmlAttr, xmlNs, _isElement, _getNs +from lxml.includes.tree cimport const_xmlChar, xmlChar, _xcstr +from lxml.python cimport _cstr, _isString +from lxml.includes cimport xpath +from lxml.includes cimport c14n + +# Cython's standard declarations +cimport cpython.mem +cimport cpython.ref +from libc cimport limits, stdio, stdlib +from libc cimport string as cstring_h # not to be confused with stdlib 'string' +from libc.string cimport const_char + +cdef object os_path_abspath +from os.path import abspath as os_path_abspath + +cdef object BytesIO, StringIO +from io import BytesIO, StringIO + +cdef object OrderedDict +from collections import OrderedDict + +cdef object _elementpath +from lxml import _elementpath + +cdef object sys +import sys + +cdef object re +import re + +cdef object partial +from functools import partial + +cdef object islice +from itertools import islice + +cdef object ITER_EMPTY = iter(()) + +cdef object MutableMapping +from collections.abc import MutableMapping + +class _ImmutableMapping(MutableMapping): + def __getitem__(self, key): + raise KeyError, key + + def __setitem__(self, key, value): + raise KeyError, key + + def __delitem__(self, key): + raise KeyError, key + + def __contains__(self, key): + return False + + def __len__(self): + return 0 + + def __iter__(self): + return ITER_EMPTY + iterkeys = itervalues = iteritems = __iter__ + +cdef object IMMUTABLE_EMPTY_MAPPING = _ImmutableMapping() +del _ImmutableMapping + + +# the rules +# --------- +# any libxml C argument/variable is prefixed with c_ +# any non-public function/class is prefixed with an underscore +# instance creation is always through factories + +# what to do with libxml2/libxslt error messages? +# 0 : drop +# 1 : use log +DEF __DEBUG = 1 + +# maximum number of lines in the libxml2/xslt log if __DEBUG == 1 +DEF __MAX_LOG_SIZE = 100 + +# make the compiled-in debug state publicly available +DEBUG = __DEBUG + +# A struct to store a cached qualified tag name+href pair. +# While we can borrow the c_name from the document dict, +# PyPy requires us to store a Python reference for the +# namespace in order to keep the byte buffer alive. +cdef struct qname: + const_xmlChar* c_name + python.PyObject* href + +# initialize parser (and threading) +xmlparser.xmlInitParser() + +# global per-thread setup +tree.xmlThrDefIndentTreeOutput(1) +tree.xmlThrDefLineNumbersDefaultValue(1) + +_initThreadLogging() + +# filename encoding +cdef bytes _FILENAME_ENCODING = (sys.getfilesystemencoding() or sys.getdefaultencoding() or 'ascii').encode("UTF-8") +cdef char* _C_FILENAME_ENCODING = _cstr(_FILENAME_ENCODING) + +# set up some default namespace prefixes +cdef dict _DEFAULT_NAMESPACE_PREFIXES = { + b"http://www.w3.org/XML/1998/namespace": b'xml', + b"http://www.w3.org/1999/xhtml": b"html", + b"http://www.w3.org/1999/XSL/Transform": b"xsl", + b"http://www.w3.org/1999/02/22-rdf-syntax-ns#": b"rdf", + b"http://schemas.xmlsoap.org/wsdl/": b"wsdl", + # xml schema + b"http://www.w3.org/2001/XMLSchema": b"xs", + b"http://www.w3.org/2001/XMLSchema-instance": b"xsi", + # dublin core + b"http://purl.org/dc/elements/1.1/": b"dc", + # objectify + b"http://codespeak.net/lxml/objectify/pytype" : b"py", +} + +# To avoid runtime encoding overhead, we keep a Unicode copy +# of the uri-prefix mapping as (str, str) items view. +cdef object _DEFAULT_NAMESPACE_PREFIXES_ITEMS = [] + +cdef _update_default_namespace_prefixes_items(): + cdef bytes ns, prefix + global _DEFAULT_NAMESPACE_PREFIXES_ITEMS + _DEFAULT_NAMESPACE_PREFIXES_ITEMS = { + ns.decode('utf-8') : prefix.decode('utf-8') + for ns, prefix in _DEFAULT_NAMESPACE_PREFIXES.items() + }.items() + +_update_default_namespace_prefixes_items() + +cdef object _check_internal_prefix = re.compile(br"ns\d+$").match + +def register_namespace(prefix, uri): + """Registers a namespace prefix that newly created Elements in that + namespace will use. The registry is global, and any existing + mapping for either the given prefix or the namespace URI will be + removed. + """ + prefix_utf, uri_utf = _utf8(prefix), _utf8(uri) + if _check_internal_prefix(prefix_utf): + raise ValueError("Prefix format reserved for internal use") + _tagValidOrRaise(prefix_utf) + _uriValidOrRaise(uri_utf) + if (uri_utf == b"http://www.w3.org/XML/1998/namespace" and prefix_utf != b'xml' + or prefix_utf == b'xml' and uri_utf != b"http://www.w3.org/XML/1998/namespace"): + raise ValueError("Cannot change the 'xml' prefix of the XML namespace") + for k, v in list(_DEFAULT_NAMESPACE_PREFIXES.items()): + if k == uri_utf or v == prefix_utf: + del _DEFAULT_NAMESPACE_PREFIXES[k] + _DEFAULT_NAMESPACE_PREFIXES[uri_utf] = prefix_utf + _update_default_namespace_prefixes_items() + + +# Error superclass for ElementTree compatibility +cdef class Error(Exception): + pass + +# module level superclass for all exceptions +cdef class LxmlError(Error): + """Main exception base class for lxml. All other exceptions inherit from + this one. + """ + def __init__(self, message, error_log=None): + super(_Error, self).__init__(message) + if error_log is None: + self.error_log = __copyGlobalErrorLog() + else: + self.error_log = error_log.copy() + +cdef object _Error = Error + + +# superclass for all syntax errors +class LxmlSyntaxError(LxmlError, SyntaxError): + """Base class for all syntax errors. + """ + +cdef class C14NError(LxmlError): + """Error during C14N serialisation. + """ + +# version information +cdef tuple __unpackDottedVersion(version): + version_list = [] + l = (version.decode("ascii").replace('-', '.').split('.') + [0]*4)[:4] + for item in l: + try: + item = int(item) + except ValueError: + if item.startswith('dev'): + count = item[3:] + item = -300 + elif item.startswith('alpha'): + count = item[5:] + item = -200 + elif item.startswith('beta'): + count = item[4:] + item = -100 + else: + count = 0 + if count: + item += int(count) + version_list.append(item) + return tuple(version_list) + +cdef tuple __unpackIntVersion(int c_version, int base=100): + return ( + ((c_version // (base*base)) % base), + ((c_version // base) % base), + (c_version % base) + ) + +cdef int _LIBXML_VERSION_INT +try: + _LIBXML_VERSION_INT = int( + re.match('[0-9]+', (tree.xmlParserVersion).decode("ascii")).group(0)) +except Exception: + print("Unknown libxml2 version: " + (tree.xmlParserVersion).decode("latin1")) + _LIBXML_VERSION_INT = 0 + +LIBXML_VERSION = __unpackIntVersion(_LIBXML_VERSION_INT) +LIBXML_COMPILED_VERSION = __unpackIntVersion(tree.LIBXML_VERSION) +LXML_VERSION = __unpackDottedVersion(tree.LXML_VERSION_STRING) + +__version__ = tree.LXML_VERSION_STRING.decode("ascii") + +cdef extern from *: + """ + #ifdef ZLIB_VERNUM + #define __lxml_zlib_version (ZLIB_VERNUM >> 4) + #else + #define __lxml_zlib_version 0 + #endif + #ifdef _LIBICONV_VERSION + #define __lxml_iconv_version (_LIBICONV_VERSION << 8) + #else + #define __lxml_iconv_version 0 + #endif + """ + # zlib isn't included automatically by libxml2's headers + #long ZLIB_HEX_VERSION "__lxml_zlib_version" + long LIBICONV_HEX_VERSION "__lxml_iconv_version" + +#ZLIB_COMPILED_VERSION = __unpackIntVersion(ZLIB_HEX_VERSION, base=0x10) +ICONV_COMPILED_VERSION = __unpackIntVersion(LIBICONV_HEX_VERSION, base=0x100)[:2] + + +cdef extern from "libxml/xmlversion.h": + """ + static const char* const _lxml_lib_features[] = { +#ifdef LIBXML_HTML_ENABLED + "html", +#endif +#ifdef LIBXML_FTP_ENABLED + "ftp", +#endif +#ifdef LIBXML_HTTP_ENABLED + "http", +#endif +#ifdef LIBXML_CATALOG_ENABLED + "catalog", +#endif +#ifdef LIBXML_XPATH_ENABLED + "xpath", +#endif +#ifdef LIBXML_ICONV_ENABLED + "iconv", +#endif +#ifdef LIBXML_ICU_ENABLED + "icu", +#endif +#ifdef LIBXML_REGEXP_ENABLED + "regexp", +#endif +#ifdef LIBXML_SCHEMAS_ENABLED + "xmlschema", +#endif +#ifdef LIBXML_SCHEMATRON_ENABLED + "schematron", +#endif +#ifdef LIBXML_ZLIB_ENABLED + "zlib", +#endif +#ifdef LIBXML_LZMA_ENABLED + "lzma", +#endif + 0 + }; + """ + const char* const* _LXML_LIB_FEATURES "_lxml_lib_features" + + +cdef set _copy_lib_features(): + features = set() + feature = _LXML_LIB_FEATURES + while feature[0]: + features.add(feature[0].decode('ASCII')) + feature += 1 + return features + +LIBXML_COMPILED_FEATURES = _copy_lib_features() +LIBXML_FEATURES = { + feature_name for feature_id, feature_name in [ + #XML_WITH_THREAD = 1 + #XML_WITH_TREE = 2 + #XML_WITH_OUTPUT = 3 + #XML_WITH_PUSH = 4 + #XML_WITH_READER = 5 + #XML_WITH_PATTERN = 6 + #XML_WITH_WRITER = 7 + #XML_WITH_SAX1 = 8 + (xmlparser.XML_WITH_FTP, "ftp"), # XML_WITH_FTP = 9 + (xmlparser.XML_WITH_HTTP, "http"), # XML_WITH_HTTP = 10 + #XML_WITH_VALID = 11 + (xmlparser.XML_WITH_HTML, "html"), # XML_WITH_HTML = 12 + #XML_WITH_LEGACY = 13 + #XML_WITH_C14N = 14 + (xmlparser.XML_WITH_CATALOG, "catalog"), # XML_WITH_CATALOG = 15 + (xmlparser.XML_WITH_XPATH, "xpath"), # XML_WITH_XPATH = 16 + #XML_WITH_XPTR = 17 + #XML_WITH_XINCLUDE = 18 + (xmlparser.XML_WITH_ICONV, "iconv"), # XML_WITH_ICONV = 19 + #XML_WITH_ISO8859X = 20 + #XML_WITH_UNICODE = 21 + (xmlparser.XML_WITH_REGEXP, "regexp"), # XML_WITH_REGEXP = 22 + #XML_WITH_AUTOMATA = 23 + #XML_WITH_EXPR = 24 + (xmlparser.XML_WITH_SCHEMAS, "xmlschema"), # XML_WITH_SCHEMAS = 25 + (xmlparser.XML_WITH_SCHEMATRON, "schematron"), # XML_WITH_SCHEMATRON = 26 + #XML_WITH_MODULES = 27 + #XML_WITH_DEBUG = 28 + #XML_WITH_DEBUG_MEM = 29 + #XML_WITH_DEBUG_RUN = 30 # unused + (xmlparser.XML_WITH_ZLIB, "zlib"), # XML_WITH_ZLIB = 31 + (xmlparser.XML_WITH_ICU, "icu"), # XML_WITH_ICU = 32 + (xmlparser.XML_WITH_LZMA, "lzma"), # XML_WITH_LZMA = 33 + ] if xmlparser.xmlHasFeature(feature_id) +} + +cdef bint HAS_ZLIB_COMPRESSION = xmlparser.xmlHasFeature(xmlparser.XML_WITH_ZLIB) + + +# class for temporary storage of Python references, +# used e.g. for XPath results +@cython.final +@cython.internal +cdef class _TempStore: + cdef list _storage + def __init__(self): + self._storage = [] + + cdef int add(self, obj) except -1: + self._storage.append(obj) + return 0 + + cdef int clear(self) except -1: + del self._storage[:] + return 0 + + +# class for temporarily storing exceptions raised in extensions +@cython.internal +cdef class _ExceptionContext: + cdef object _exc_info + cdef int clear(self) except -1: + self._exc_info = None + return 0 + + cdef void _store_raised(self) noexcept: + try: + self._exc_info = sys.exc_info() + except BaseException as e: + self._store_exception(e) + finally: + return # and swallow any further exceptions + + cdef int _store_exception(self, exception) except -1: + self._exc_info = (exception, None, None) + return 0 + + cdef bint _has_raised(self) except -1: + return self._exc_info is not None + + cdef int _raise_if_stored(self) except -1: + if self._exc_info is None: + return 0 + type, value, traceback = self._exc_info + self._exc_info = None + if value is None and traceback is None: + raise type + else: + raise type, value, traceback + + +# type of a function that steps from node to node +ctypedef public xmlNode* (*_node_to_node_function)(xmlNode*) + + +################################################################################ +# Include submodules + +include "proxy.pxi" # Proxy handling (element backpointers/memory/etc.) +include "apihelpers.pxi" # Private helper functions +include "xmlerror.pxi" # Error and log handling + + +################################################################################ +# Public Python API + +@cython.final +@cython.freelist(8) +cdef public class _Document [ type LxmlDocumentType, object LxmlDocument ]: + """Internal base class to reference a libxml document. + + When instances of this class are garbage collected, the libxml + document is cleaned up. + """ + cdef int _ns_counter + cdef bytes _prefix_tail + cdef xmlDoc* _c_doc + cdef _BaseParser _parser + + def __dealloc__(self): + # if there are no more references to the document, it is safe + # to clean the whole thing up, as all nodes have a reference to + # the document + tree.xmlFreeDoc(self._c_doc) + + @cython.final + cdef getroot(self): + # return an element proxy for the document root + cdef xmlNode* c_node + c_node = tree.xmlDocGetRootElement(self._c_doc) + if c_node is NULL: + return None + return _elementFactory(self, c_node) + + @cython.final + cdef bint hasdoctype(self) noexcept: + # DOCTYPE gets parsed into internal subset (xmlDTD*) + return self._c_doc is not NULL and self._c_doc.intSubset is not NULL + + @cython.final + cdef getdoctype(self): + # get doctype info: root tag, public/system ID (or None if not known) + cdef tree.xmlDtd* c_dtd + cdef xmlNode* c_root_node + public_id = None + sys_url = None + c_dtd = self._c_doc.intSubset + if c_dtd is not NULL: + if c_dtd.ExternalID is not NULL: + public_id = funicode(c_dtd.ExternalID) + if c_dtd.SystemID is not NULL: + sys_url = funicode(c_dtd.SystemID) + c_dtd = self._c_doc.extSubset + if c_dtd is not NULL: + if not public_id and c_dtd.ExternalID is not NULL: + public_id = funicode(c_dtd.ExternalID) + if not sys_url and c_dtd.SystemID is not NULL: + sys_url = funicode(c_dtd.SystemID) + c_root_node = tree.xmlDocGetRootElement(self._c_doc) + if c_root_node is NULL: + root_name = None + else: + root_name = funicode(c_root_node.name) + return root_name, public_id, sys_url + + @cython.final + cdef getxmlinfo(self): + # return XML version and encoding (or None if not known) + cdef xmlDoc* c_doc = self._c_doc + if c_doc.version is NULL: + version = None + else: + version = funicode(c_doc.version) + if c_doc.encoding is NULL: + encoding = None + else: + encoding = funicode(c_doc.encoding) + return version, encoding + + @cython.final + cdef isstandalone(self): + # returns True for "standalone=true", + # False for "standalone=false", None if not provided + if self._c_doc.standalone == -1: + return None + else: + return (self._c_doc.standalone == 1) + + @cython.final + cdef bytes buildNewPrefix(self): + # get a new unique prefix ("nsX") for this document + cdef bytes ns + if self._ns_counter < len(_PREFIX_CACHE): + ns = _PREFIX_CACHE[self._ns_counter] + else: + ns = python.PyBytes_FromFormat("ns%d", self._ns_counter) + if self._prefix_tail is not None: + ns += self._prefix_tail + self._ns_counter += 1 + if self._ns_counter < 0: + # overflow! + self._ns_counter = 0 + if self._prefix_tail is None: + self._prefix_tail = b"A" + else: + self._prefix_tail += b"A" + return ns + + @cython.final + cdef xmlNs* _findOrBuildNodeNs(self, xmlNode* c_node, + const_xmlChar* c_href, const_xmlChar* c_prefix, + bint is_attribute) except NULL: + """Get or create namespace structure for a node. Reuses the prefix if + possible. + """ + cdef xmlNs* c_ns + cdef xmlNs* c_doc_ns + cdef python.PyObject* dict_result + if c_node.type != tree.XML_ELEMENT_NODE: + assert c_node.type == tree.XML_ELEMENT_NODE, \ + "invalid node type %d, expected %d" % ( + c_node.type, tree.XML_ELEMENT_NODE) + # look for existing ns declaration + c_ns = _searchNsByHref(c_node, c_href, is_attribute) + if c_ns is not NULL: + if is_attribute and c_ns.prefix is NULL: + # do not put namespaced attributes into the default + # namespace as this would break serialisation + pass + else: + return c_ns + + # none found => determine a suitable new prefix + if c_prefix is NULL: + dict_result = python.PyDict_GetItem( + _DEFAULT_NAMESPACE_PREFIXES, c_href) + if dict_result is not NULL: + prefix = dict_result + else: + prefix = self.buildNewPrefix() + c_prefix = _xcstr(prefix) + + # make sure the prefix is not in use already + while tree.xmlSearchNs(self._c_doc, c_node, c_prefix) is not NULL: + prefix = self.buildNewPrefix() + c_prefix = _xcstr(prefix) + + # declare the namespace and return it + c_ns = tree.xmlNewNs(c_node, c_href, c_prefix) + if c_ns is NULL: + raise MemoryError() + return c_ns + + @cython.final + cdef int _setNodeNs(self, xmlNode* c_node, const_xmlChar* c_href) except -1: + "Lookup namespace structure and set it for the node." + c_ns = self._findOrBuildNodeNs(c_node, c_href, NULL, 0) + tree.xmlSetNs(c_node, c_ns) + + +cdef tuple __initPrefixCache(): + cdef int i + return tuple([ python.PyBytes_FromFormat("ns%d", i) + for i in range(26) ]) + +cdef tuple _PREFIX_CACHE = __initPrefixCache() + + +cdef _Document _documentFactory(xmlDoc* c_doc, _BaseParser parser): + cdef _Document result + result = _Document.__new__(_Document) + result._c_doc = c_doc + result._ns_counter = 0 + result._prefix_tail = None + if parser is None: + parser = __GLOBAL_PARSER_CONTEXT.getDefaultParser() + result._parser = parser + return result + + +cdef object _find_invalid_public_id_characters = re.compile( + ur"[^\x20\x0D\x0Aa-zA-Z0-9'()+,./:=?;!*#@$_%-]+").search + + +cdef class DocInfo: + "Document information provided by parser and DTD." + cdef _Document _doc + def __cinit__(self, tree): + "Create a DocInfo object for an ElementTree object or root Element." + self._doc = _documentOrRaise(tree) + root_name, public_id, system_url = self._doc.getdoctype() + if not root_name and (public_id or system_url): + raise ValueError, "Could not find root node" + + @property + def root_name(self): + """Returns the name of the root node as defined by the DOCTYPE.""" + root_name, public_id, system_url = self._doc.getdoctype() + return root_name + + @cython.final + cdef tree.xmlDtd* _get_c_dtd(self): + """"Return the DTD. Create it if it does not yet exist.""" + cdef xmlDoc* c_doc = self._doc._c_doc + cdef xmlNode* c_root_node + cdef const_xmlChar* c_name + + if c_doc.intSubset: + return c_doc.intSubset + + c_root_node = tree.xmlDocGetRootElement(c_doc) + c_name = c_root_node.name if c_root_node else NULL + return tree.xmlCreateIntSubset(c_doc, c_name, NULL, NULL) + + def clear(self): + """Removes DOCTYPE and internal subset from the document.""" + cdef xmlDoc* c_doc = self._doc._c_doc + cdef tree.xmlNode* c_dtd = c_doc.intSubset + if c_dtd is NULL: + return + tree.xmlUnlinkNode(c_dtd) + tree.xmlFreeNode(c_dtd) + + property public_id: + """Public ID of the DOCTYPE. + + Mutable. May be set to a valid string or None. If a DTD does not + exist, setting this variable (even to None) will create one. + """ + def __get__(self): + root_name, public_id, system_url = self._doc.getdoctype() + return public_id + + def __set__(self, value): + cdef xmlChar* c_value = NULL + if value is not None: + match = _find_invalid_public_id_characters(value) + if match: + raise ValueError, f'Invalid character(s) {match.group(0)!r} in public_id.' + value = _utf8(value) + c_value = tree.xmlStrdup(_xcstr(value)) + if not c_value: + raise MemoryError() + + c_dtd = self._get_c_dtd() + if not c_dtd: + tree.xmlFree(c_value) + raise MemoryError() + if c_dtd.ExternalID: + tree.xmlFree(c_dtd.ExternalID) + c_dtd.ExternalID = c_value + + property system_url: + """System ID of the DOCTYPE. + + Mutable. May be set to a valid string or None. If a DTD does not + exist, setting this variable (even to None) will create one. + """ + def __get__(self): + root_name, public_id, system_url = self._doc.getdoctype() + return system_url + + def __set__(self, value): + cdef xmlChar* c_value = NULL + if value is not None: + bvalue = _utf8(value) + # sys_url may be any valid unicode string that can be + # enclosed in single quotes or quotes. + if b"'" in bvalue and b'"' in bvalue: + raise ValueError( + 'System URL may not contain both single (\') and double quotes (").') + c_value = tree.xmlStrdup(_xcstr(bvalue)) + if not c_value: + raise MemoryError() + + c_dtd = self._get_c_dtd() + if not c_dtd: + tree.xmlFree(c_value) + raise MemoryError() + if c_dtd.SystemID: + tree.xmlFree(c_dtd.SystemID) + c_dtd.SystemID = c_value + + @property + def xml_version(self): + """Returns the XML version as declared by the document.""" + xml_version, encoding = self._doc.getxmlinfo() + return xml_version + + @property + def encoding(self): + """Returns the encoding name as declared by the document.""" + xml_version, encoding = self._doc.getxmlinfo() + return encoding + + @property + def standalone(self): + """Returns the standalone flag as declared by the document. The possible + values are True (``standalone='yes'``), False + (``standalone='no'`` or flag not provided in the declaration), + and None (unknown or no declaration found). Note that a + normal truth test on this value will always tell if the + ``standalone`` flag was set to ``'yes'`` or not. + """ + return self._doc.isstandalone() + + property URL: + "The source URL of the document (or None if unknown)." + def __get__(self): + if self._doc._c_doc.URL is NULL: + return None + return _decodeFilename(self._doc._c_doc.URL) + def __set__(self, url): + url = _encodeFilename(url) + c_oldurl = self._doc._c_doc.URL + if url is None: + self._doc._c_doc.URL = NULL + else: + self._doc._c_doc.URL = tree.xmlStrdup(_xcstr(url)) + if c_oldurl is not NULL: + tree.xmlFree(c_oldurl) + + @property + def doctype(self): + """Returns a DOCTYPE declaration string for the document.""" + root_name, public_id, system_url = self._doc.getdoctype() + if system_url: + # If '"' in system_url, we must escape it with single + # quotes, otherwise escape with double quotes. If url + # contains both a single quote and a double quote, XML + # standard is being violated. + if '"' in system_url: + quoted_system_url = f"'{system_url}'" + else: + quoted_system_url = f'"{system_url}"' + if public_id: + if system_url: + return f'' + else: + return f'' + elif system_url: + return f'' + elif self._doc.hasdoctype(): + return f'' + else: + return '' + + @property + def internalDTD(self): + """Returns a DTD validator based on the internal subset of the document.""" + return _dtdFactory(self._doc._c_doc.intSubset) + + @property + def externalDTD(self): + """Returns a DTD validator based on the external subset of the document.""" + return _dtdFactory(self._doc._c_doc.extSubset) + + +@cython.no_gc_clear +cdef public class _Element [ type LxmlElementType, object LxmlElement ]: + """Element class. + + References a document object and a libxml node. + + By pointing to a Document instance, a reference is kept to + _Document as long as there is some pointer to a node in it. + """ + cdef _Document _doc + cdef xmlNode* _c_node + cdef object _tag + + def _init(self): + """_init(self) + + Called after object initialisation. Custom subclasses may override + this if they recursively call _init() in the superclasses. + """ + + @cython.linetrace(False) + @cython.profile(False) + def __dealloc__(self): + #print("trying to free node:", self._c_node) + #displayNode(self._c_node, 0) + if self._c_node is not NULL: + _unregisterProxy(self) + attemptDeallocation(self._c_node) + + # MANIPULATORS + + def __setitem__(self, x, value): + """__setitem__(self, x, value) + + Replaces the given subelement index or slice. + """ + cdef xmlNode* c_node = NULL + cdef xmlNode* c_next + cdef xmlDoc* c_source_doc + cdef _Element element + cdef bint left_to_right + cdef Py_ssize_t slicelength = 0, step = 0 + _assertValidNode(self) + if value is None: + raise ValueError, "cannot assign None" + if isinstance(x, slice): + # slice assignment + _findChildSlice(x, self._c_node, &c_node, &step, &slicelength) + if step > 0: + left_to_right = 1 + else: + left_to_right = 0 + step = -step + _replaceSlice(self, c_node, slicelength, step, left_to_right, value) + return + else: + # otherwise: normal item assignment + element = value + _assertValidNode(element) + c_node = _findChild(self._c_node, x) + if c_node is NULL: + raise IndexError, "list index out of range" + c_source_doc = element._c_node.doc + c_next = element._c_node.next + _removeText(c_node.next) + tree.xmlReplaceNode(c_node, element._c_node) + _moveTail(c_next, element._c_node) + moveNodeToDocument(self._doc, c_source_doc, element._c_node) + if not attemptDeallocation(c_node): + moveNodeToDocument(self._doc, c_node.doc, c_node) + + def __delitem__(self, x): + """__delitem__(self, x) + + Deletes the given subelement or a slice. + """ + cdef xmlNode* c_node = NULL + cdef xmlNode* c_next + cdef Py_ssize_t step = 0, slicelength = 0 + _assertValidNode(self) + if isinstance(x, slice): + # slice deletion + if _isFullSlice(x): + c_node = self._c_node.children + if c_node is not NULL: + if not _isElement(c_node): + c_node = _nextElement(c_node) + while c_node is not NULL: + c_next = _nextElement(c_node) + _removeNode(self._doc, c_node) + c_node = c_next + else: + _findChildSlice(x, self._c_node, &c_node, &step, &slicelength) + _deleteSlice(self._doc, c_node, slicelength, step) + else: + # item deletion + c_node = _findChild(self._c_node, x) + if c_node is NULL: + raise IndexError, f"index out of range: {x}" + _removeNode(self._doc, c_node) + + def __deepcopy__(self, memo): + "__deepcopy__(self, memo)" + return self.__copy__() + + def __copy__(self): + "__copy__(self)" + cdef xmlDoc* c_doc + cdef xmlNode* c_node + cdef _Document new_doc + _assertValidNode(self) + c_doc = _copyDocRoot(self._doc._c_doc, self._c_node) # recursive + new_doc = _documentFactory(c_doc, self._doc._parser) + root = new_doc.getroot() + if root is not None: + return root + # Comment/PI + c_node = c_doc.children + while c_node is not NULL and c_node.type != self._c_node.type: + c_node = c_node.next + if c_node is NULL: + return None + return _elementFactory(new_doc, c_node) + + def set(self, key, value): + """set(self, key, value) + + Sets an element attribute. + In HTML documents (not XML or XHTML), the value None is allowed and creates + an attribute without value (just the attribute name). + """ + _assertValidNode(self) + _setAttributeValue(self, key, value) + + def append(self, _Element element not None): + """append(self, element) + + Adds a subelement to the end of this element. + """ + _assertValidNode(self) + _assertValidNode(element) + _appendChild(self, element) + + def addnext(self, _Element element not None): + """addnext(self, element) + + Adds the element as a following sibling directly after this + element. + + This is normally used to set a processing instruction or comment after + the root node of a document. Note that tail text is automatically + discarded when adding at the root level. + """ + _assertValidNode(self) + _assertValidNode(element) + if self._c_node.parent != NULL and not _isElement(self._c_node.parent): + if element._c_node.type not in (tree.XML_PI_NODE, tree.XML_COMMENT_NODE): + raise TypeError, "Only processing instructions and comments can be siblings of the root element" + element.tail = None + _appendSibling(self, element) + + def addprevious(self, _Element element not None): + """addprevious(self, element) + + Adds the element as a preceding sibling directly before this + element. + + This is normally used to set a processing instruction or comment + before the root node of a document. Note that tail text is + automatically discarded when adding at the root level. + """ + _assertValidNode(self) + _assertValidNode(element) + if self._c_node.parent != NULL and not _isElement(self._c_node.parent): + if element._c_node.type != tree.XML_PI_NODE: + if element._c_node.type != tree.XML_COMMENT_NODE: + raise TypeError, "Only processing instructions and comments can be siblings of the root element" + element.tail = None + _prependSibling(self, element) + + def extend(self, elements): + """extend(self, elements) + + Extends the current children by the elements in the iterable. + """ + cdef _Element element + _assertValidNode(self) + for element in elements: + if element is None: + raise TypeError, "Node must not be None" + _assertValidNode(element) + _appendChild(self, element) + + def clear(self, bint keep_tail=False): + """clear(self, keep_tail=False) + + Resets an element. This function removes all subelements, clears + all attributes and sets the text and tail properties to None. + + Pass ``keep_tail=True`` to leave the tail text untouched. + """ + cdef xmlAttr* c_attr + cdef xmlAttr* c_attr_next + cdef xmlNode* c_node + cdef xmlNode* c_node_next + _assertValidNode(self) + c_node = self._c_node + # remove self.text and self.tail + _removeText(c_node.children) + if not keep_tail: + _removeText(c_node.next) + # remove all attributes + c_attr = c_node.properties + if c_attr: + c_node.properties = NULL + tree.xmlFreePropList(c_attr) + # remove all subelements + c_node = c_node.children + if c_node and not _isElement(c_node): + c_node = _nextElement(c_node) + while c_node is not NULL: + c_node_next = _nextElement(c_node) + _removeNode(self._doc, c_node) + c_node = c_node_next + + def insert(self, index: int, _Element element not None): + """insert(self, index, element) + + Inserts a subelement at the given position in this element + """ + cdef xmlNode* c_node + cdef xmlNode* c_next + cdef xmlDoc* c_source_doc + _assertValidNode(self) + _assertValidNode(element) + c_node = _findChild(self._c_node, index) + if c_node is NULL: + _appendChild(self, element) + return + # prevent cycles + if _isAncestorOrSame(element._c_node, self._c_node): + raise ValueError("cannot append parent to itself") + c_source_doc = element._c_node.doc + c_next = element._c_node.next + tree.xmlAddPrevSibling(c_node, element._c_node) + _moveTail(c_next, element._c_node) + moveNodeToDocument(self._doc, c_source_doc, element._c_node) + + def remove(self, _Element element not None): + """remove(self, element) + + Removes a matching subelement. Unlike the find methods, this + method compares elements based on identity, not on tag value + or contents. + """ + cdef xmlNode* c_node + cdef xmlNode* c_next + _assertValidNode(self) + _assertValidNode(element) + c_node = element._c_node + if c_node.parent is not self._c_node: + raise ValueError, "Element is not a child of this node." + c_next = element._c_node.next + tree.xmlUnlinkNode(c_node) + _moveTail(c_next, c_node) + # fix namespace declarations + moveNodeToDocument(self._doc, c_node.doc, c_node) + + def replace(self, _Element old_element not None, + _Element new_element not None): + """replace(self, old_element, new_element) + + Replaces a subelement with the element passed as second argument. + """ + cdef xmlNode* c_old_node + cdef xmlNode* c_old_next + cdef xmlNode* c_new_node + cdef xmlNode* c_new_next + cdef xmlDoc* c_source_doc + _assertValidNode(self) + _assertValidNode(old_element) + _assertValidNode(new_element) + c_old_node = old_element._c_node + if c_old_node.parent is not self._c_node: + raise ValueError, "Element is not a child of this node." + c_new_node = new_element._c_node + # prevent cycles + if _isAncestorOrSame(c_new_node, self._c_node): + raise ValueError("cannot append parent to itself") + # replace node + c_old_next = c_old_node.next + c_new_next = c_new_node.next + c_source_doc = c_new_node.doc + tree.xmlReplaceNode(c_old_node, c_new_node) + _moveTail(c_new_next, c_new_node) + _moveTail(c_old_next, c_old_node) + moveNodeToDocument(self._doc, c_source_doc, c_new_node) + # fix namespace declarations + moveNodeToDocument(self._doc, c_old_node.doc, c_old_node) + + # PROPERTIES + property tag: + """Element tag + """ + def __get__(self): + if self._tag is not None: + return self._tag + _assertValidNode(self) + self._tag = _namespacedName(self._c_node) + return self._tag + + def __set__(self, value): + cdef _BaseParser parser + _assertValidNode(self) + ns, name = _getNsTag(value) + parser = self._doc._parser + if parser is not None and parser._for_html: + _htmlTagValidOrRaise(name) + else: + _tagValidOrRaise(name) + self._tag = value + tree.xmlNodeSetName(self._c_node, _xcstr(name)) + if ns is None: + self._c_node.ns = NULL + else: + self._doc._setNodeNs(self._c_node, _xcstr(ns)) + + @property + def attrib(self): + """Element attribute dictionary. Where possible, use get(), set(), + keys(), values() and items() to access element attributes. + """ + return _Attrib.__new__(_Attrib, self) + + property text: + """Text before the first subelement. This is either a string or + the value None, if there was no text. + """ + def __get__(self): + _assertValidNode(self) + return _collectText(self._c_node.children) + + def __set__(self, value): + _assertValidNode(self) + if isinstance(value, QName): + value = _resolveQNameText(self, value).decode('utf8') + _setNodeText(self._c_node, value) + + # using 'del el.text' is the wrong thing to do + #def __del__(self): + # _setNodeText(self._c_node, None) + + property tail: + """Text after this element's end tag, but before the next sibling + element's start tag. This is either a string or the value None, if + there was no text. + """ + def __get__(self): + _assertValidNode(self) + return _collectText(self._c_node.next) + + def __set__(self, value): + _assertValidNode(self) + _setTailText(self._c_node, value) + + # using 'del el.tail' is the wrong thing to do + #def __del__(self): + # _setTailText(self._c_node, None) + + # not in ElementTree, read-only + @property + def prefix(self): + """Namespace prefix or None. + """ + if self._c_node.ns is not NULL: + if self._c_node.ns.prefix is not NULL: + return funicode(self._c_node.ns.prefix) + return None + + # not in ElementTree, read-only + property sourceline: + """Original line number as found by the parser or None if unknown. + """ + def __get__(self): + cdef long line + _assertValidNode(self) + line = tree.xmlGetLineNo(self._c_node) + return line if line > 0 else None + + def __set__(self, line): + _assertValidNode(self) + if line <= 0: + self._c_node.line = 0 + else: + self._c_node.line = line + + # not in ElementTree, read-only + @property + def nsmap(self): + """Namespace prefix->URI mapping known in the context of this + Element. This includes all namespace declarations of the + parents. + + Note that changing the returned dict has no effect on the Element. + """ + _assertValidNode(self) + return _build_nsmap(self._c_node) + + # not in ElementTree, read-only + property base: + """The base URI of the Element (xml:base or HTML base URL). + None if the base URI is unknown. + + Note that the value depends on the URL of the document that + holds the Element if there is no xml:base attribute on the + Element or its ancestors. + + Setting this property will set an xml:base attribute on the + Element, regardless of the document type (XML or HTML). + """ + def __get__(self): + _assertValidNode(self) + c_base = tree.xmlNodeGetBase(self._doc._c_doc, self._c_node) + if c_base is NULL: + if self._doc._c_doc.URL is NULL: + return None + return _decodeFilename(self._doc._c_doc.URL) + try: + base = _decodeFilename(c_base) + finally: + tree.xmlFree(c_base) + return base + + def __set__(self, url): + _assertValidNode(self) + if url is None: + c_base = NULL + else: + url = _encodeFilename(url) + c_base = _xcstr(url) + tree.xmlNodeSetBase(self._c_node, c_base) + + # ACCESSORS + def __repr__(self): + "__repr__(self)" + return "" % (self.tag, id(self)) + + def __getitem__(self, x): + """Returns the subelement at the given position or the requested + slice. + """ + cdef xmlNode* c_node = NULL + cdef Py_ssize_t step = 0, slicelength = 0 + cdef Py_ssize_t c, i + cdef _node_to_node_function next_element + cdef list result + _assertValidNode(self) + if isinstance(x, slice): + # slicing + if _isFullSlice(x): + return _collectChildren(self) + _findChildSlice(x, self._c_node, &c_node, &step, &slicelength) + if c_node is NULL: + return [] + if step > 0: + next_element = _nextElement + else: + step = -step + next_element = _previousElement + result = [] + c = 0 + while c_node is not NULL and c < slicelength: + result.append(_elementFactory(self._doc, c_node)) + c += 1 + for i in range(step): + c_node = next_element(c_node) + if c_node is NULL: + break + return result + else: + # indexing + c_node = _findChild(self._c_node, x) + if c_node is NULL: + raise IndexError, "list index out of range" + return _elementFactory(self._doc, c_node) + + def __len__(self): + """__len__(self) + + Returns the number of subelements. + """ + _assertValidNode(self) + return _countElements(self._c_node.children) + + def __bool__(self): + """__bool__(self)""" + import warnings + warnings.warn( + "Truth-testing of elements was a source of confusion and will always " + "return True in future versions. " + "Use specific 'len(elem)' or 'elem is not None' test instead.", + FutureWarning + ) + # emulate old behaviour + _assertValidNode(self) + return _hasChild(self._c_node) + + def __contains__(self, element): + "__contains__(self, element)" + cdef xmlNode* c_node + _assertValidNode(self) + if not isinstance(element, _Element): + return 0 + c_node = (<_Element>element)._c_node + return c_node is not NULL and c_node.parent is self._c_node + + def __iter__(self): + "__iter__(self)" + return ElementChildIterator(self) + + def __reversed__(self): + "__reversed__(self)" + return ElementChildIterator(self, reversed=True) + + def index(self, child: _Element, start: int = None, stop: int = None): + """index(self, child, start=None, stop=None) + + Find the position of the child within the parent. + + This method is not part of the original ElementTree API. + """ + cdef Py_ssize_t k, l + cdef Py_ssize_t c_start, c_stop + cdef xmlNode* c_child + cdef xmlNode* c_start_node + _assertValidNode(self) + _assertValidNode(child) + c_child = child._c_node + if c_child.parent is not self._c_node: + raise ValueError, "Element is not a child of this node." + + # handle the unbounded search straight away (normal case) + if stop is None and (start is None or start == 0): + k = 0 + c_child = c_child.prev + while c_child is not NULL: + if _isElement(c_child): + k += 1 + c_child = c_child.prev + return k + + # check indices + if start is None: + c_start = 0 + else: + c_start = start + if stop is None: + c_stop = 0 + else: + c_stop = stop + if c_stop == 0 or \ + c_start >= c_stop and (c_stop > 0 or c_start < 0): + raise ValueError, "list.index(x): x not in slice" + + # for negative slice indices, check slice before searching index + if c_start < 0 or c_stop < 0: + # start from right, at most up to leftmost(c_start, c_stop) + if c_start < c_stop: + k = -c_start + else: + k = -c_stop + c_start_node = self._c_node.last + l = 1 + while c_start_node != c_child and l < k: + if _isElement(c_start_node): + l += 1 + c_start_node = c_start_node.prev + if c_start_node == c_child: + # found! before slice end? + if c_stop < 0 and l <= -c_stop: + raise ValueError, "list.index(x): x not in slice" + elif c_start < 0: + raise ValueError, "list.index(x): x not in slice" + + # now determine the index backwards from child + c_child = c_child.prev + k = 0 + if c_stop > 0: + # we can optimize: stop after c_stop elements if not found + while c_child != NULL and k < c_stop: + if _isElement(c_child): + k += 1 + c_child = c_child.prev + if k < c_stop: + return k + else: + # traverse all + while c_child != NULL: + if _isElement(c_child): + k = k + 1 + c_child = c_child.prev + if c_start > 0: + if k >= c_start: + return k + else: + return k + if c_start != 0 or c_stop != 0: + raise ValueError, "list.index(x): x not in slice" + else: + raise ValueError, "list.index(x): x not in list" + + def get(self, key, default=None): + """get(self, key, default=None) + + Gets an element attribute. + """ + _assertValidNode(self) + return _getAttributeValue(self, key, default) + + def keys(self): + """keys(self) + + Gets a list of attribute names. The names are returned in an + arbitrary order (just like for an ordinary Python dictionary). + """ + _assertValidNode(self) + return _collectAttributes(self._c_node, 1) + + def values(self): + """values(self) + + Gets element attribute values as a sequence of strings. The + attributes are returned in an arbitrary order. + """ + _assertValidNode(self) + return _collectAttributes(self._c_node, 2) + + def items(self): + """items(self) + + Gets element attributes, as a sequence. The attributes are returned in + an arbitrary order. + """ + _assertValidNode(self) + return _collectAttributes(self._c_node, 3) + + def getchildren(self): + """getchildren(self) + + Returns all direct children. The elements are returned in document + order. + + :deprecated: Note that this method has been deprecated as of + ElementTree 1.3 and lxml 2.0. New code should use + ``list(element)`` or simply iterate over elements. + """ + _assertValidNode(self) + return _collectChildren(self) + + def getparent(self): + """getparent(self) + + Returns the parent of this element or None for the root element. + """ + cdef xmlNode* c_node + #_assertValidNode(self) # not needed + c_node = _parentElement(self._c_node) + if c_node is NULL: + return None + return _elementFactory(self._doc, c_node) + + def getnext(self): + """getnext(self) + + Returns the following sibling of this element or None. + """ + cdef xmlNode* c_node + #_assertValidNode(self) # not needed + c_node = _nextElement(self._c_node) + if c_node is NULL: + return None + return _elementFactory(self._doc, c_node) + + def getprevious(self): + """getprevious(self) + + Returns the preceding sibling of this element or None. + """ + cdef xmlNode* c_node + #_assertValidNode(self) # not needed + c_node = _previousElement(self._c_node) + if c_node is NULL: + return None + return _elementFactory(self._doc, c_node) + + def itersiblings(self, tag=None, *tags, preceding=False): + """itersiblings(self, tag=None, *tags, preceding=False) + + Iterate over the following or preceding siblings of this element. + + The direction is determined by the 'preceding' keyword which + defaults to False, i.e. forward iteration over the following + siblings. When True, the iterator yields the preceding + siblings in reverse document order, i.e. starting right before + the current element and going backwards. + + Can be restricted to find only elements with specific tags, + see `iter`. + """ + if preceding: + if self._c_node and not self._c_node.prev: + return ITER_EMPTY + elif self._c_node and not self._c_node.next: + return ITER_EMPTY + if tag is not None: + tags += (tag,) + return SiblingsIterator(self, tags, preceding=preceding) + + def iterancestors(self, tag=None, *tags): + """iterancestors(self, tag=None, *tags) + + Iterate over the ancestors of this element (from parent to parent). + + Can be restricted to find only elements with specific tags, + see `iter`. + """ + if self._c_node and not self._c_node.parent: + return ITER_EMPTY + if tag is not None: + tags += (tag,) + return AncestorsIterator(self, tags) + + def iterdescendants(self, tag=None, *tags): + """iterdescendants(self, tag=None, *tags) + + Iterate over the descendants of this element in document order. + + As opposed to ``el.iter()``, this iterator does not yield the element + itself. The returned elements can be restricted to find only elements + with specific tags, see `iter`. + """ + if self._c_node and not self._c_node.children: + return ITER_EMPTY + if tag is not None: + tags += (tag,) + return ElementDepthFirstIterator(self, tags, inclusive=False) + + def iterchildren(self, tag=None, *tags, reversed=False): + """iterchildren(self, tag=None, *tags, reversed=False) + + Iterate over the children of this element. + + As opposed to using normal iteration on this element, the returned + elements can be reversed with the 'reversed' keyword and restricted + to find only elements with specific tags, see `iter`. + """ + if self._c_node and not self._c_node.children: + return ITER_EMPTY + if tag is not None: + tags += (tag,) + return ElementChildIterator(self, tags, reversed=reversed) + + def getroottree(self): + """getroottree(self) + + Return an ElementTree for the root node of the document that + contains this element. + + This is the same as following element.getparent() up the tree until it + returns None (for the root element) and then build an ElementTree for + the last parent that was returned.""" + _assertValidDoc(self._doc) + return _elementTreeFactory(self._doc, None) + + def getiterator(self, tag=None, *tags): + """getiterator(self, tag=None, *tags) + + Returns a sequence or iterator of all elements in the subtree in + document order (depth first pre-order), starting with this + element. + + Can be restricted to find only elements with specific tags, + see `iter`. + + :deprecated: Note that this method is deprecated as of + ElementTree 1.3 and lxml 2.0. It returns an iterator in + lxml, which diverges from the original ElementTree + behaviour. If you want an efficient iterator, use the + ``element.iter()`` method instead. You should only use this + method in new code if you require backwards compatibility + with older versions of lxml or ElementTree. + """ + if tag is not None: + tags += (tag,) + return ElementDepthFirstIterator(self, tags) + + def iter(self, tag=None, *tags): + """iter(self, tag=None, *tags) + + Iterate over all elements in the subtree in document order (depth + first pre-order), starting with this element. + + Can be restricted to find only elements with specific tags: + pass ``"{ns}localname"`` as tag. Either or both of ``ns`` and + ``localname`` can be ``*`` for a wildcard; ``ns`` can be empty + for no namespace. ``"localname"`` is equivalent to ``"{}localname"`` + (i.e. no namespace) but ``"*"`` is ``"{*}*"`` (any or no namespace), + not ``"{}*"``. + + You can also pass the Element, Comment, ProcessingInstruction and + Entity factory functions to look only for the specific element type. + + Passing multiple tags (or a sequence of tags) instead of a single tag + will let the iterator return all elements matching any of these tags, + in document order. + """ + if tag is not None: + tags += (tag,) + return ElementDepthFirstIterator(self, tags) + + def itertext(self, tag=None, *tags, with_tail=True): + """itertext(self, tag=None, *tags, with_tail=True) + + Iterates over the text content of a subtree. + + You can pass tag names to restrict text content to specific elements, + see `iter`. + + You can set the ``with_tail`` keyword argument to ``False`` to skip + over tail text. + """ + if tag is not None: + tags += (tag,) + return ElementTextIterator(self, tags, with_tail=with_tail) + + def makeelement(self, _tag, attrib=None, nsmap=None, **_extra): + """makeelement(self, _tag, attrib=None, nsmap=None, **_extra) + + Creates a new element associated with the same document. + """ + _assertValidDoc(self._doc) + return _makeElement(_tag, NULL, self._doc, None, None, None, + attrib, nsmap, _extra) + + def find(self, path, namespaces=None): + """find(self, path, namespaces=None) + + Finds the first matching subelement, by tag name or path. + + The optional ``namespaces`` argument accepts a + prefix-to-namespace mapping that allows the usage of XPath + prefixes in the path expression. + """ + if isinstance(path, QName): + path = (path).text + return _elementpath.find(self, path, namespaces, with_prefixes=not _isHtmlDocument(self)) + + def findtext(self, path, default=None, namespaces=None): + """findtext(self, path, default=None, namespaces=None) + + Finds text for the first matching subelement, by tag name or path. + + The optional ``namespaces`` argument accepts a + prefix-to-namespace mapping that allows the usage of XPath + prefixes in the path expression. + """ + if isinstance(path, QName): + path = (path).text + return _elementpath.findtext(self, path, default, namespaces, with_prefixes=not _isHtmlDocument(self)) + + def findall(self, path, namespaces=None): + """findall(self, path, namespaces=None) + + Finds all matching subelements, by tag name or path. + + The optional ``namespaces`` argument accepts a + prefix-to-namespace mapping that allows the usage of XPath + prefixes in the path expression. + """ + if isinstance(path, QName): + path = (path).text + return _elementpath.findall(self, path, namespaces, with_prefixes=not _isHtmlDocument(self)) + + def iterfind(self, path, namespaces=None): + """iterfind(self, path, namespaces=None) + + Iterates over all matching subelements, by tag name or path. + + The optional ``namespaces`` argument accepts a + prefix-to-namespace mapping that allows the usage of XPath + prefixes in the path expression. + """ + if isinstance(path, QName): + path = (path).text + return _elementpath.iterfind(self, path, namespaces, with_prefixes=not _isHtmlDocument(self)) + + def xpath(self, _path, *, namespaces=None, extensions=None, + smart_strings=True, **_variables): + """xpath(self, _path, namespaces=None, extensions=None, smart_strings=True, **_variables) + + Evaluate an xpath expression using the element as context node. + """ + evaluator = XPathElementEvaluator(self, namespaces=namespaces, + extensions=extensions, + smart_strings=smart_strings) + return evaluator(_path, **_variables) + + def cssselect(self, expr, *, translator='xml'): + """ + Run the CSS expression on this element and its children, + returning a list of the results. + + Equivalent to lxml.cssselect.CSSSelect(expr)(self) -- note + that pre-compiling the expression can provide a substantial + speedup. + """ + # Do the import here to make the dependency optional. + from lxml.cssselect import CSSSelector + return CSSSelector(expr, translator=translator)(self) + + +@cython.linetrace(False) +cdef _Element _elementFactory(_Document doc, xmlNode* c_node): + cdef _Element result + result = getProxy(c_node) + if result is not None: + return result + if c_node is NULL: + return None + + element_class = LOOKUP_ELEMENT_CLASS( + ELEMENT_CLASS_LOOKUP_STATE, doc, c_node) + if type(element_class) is not type: + if not isinstance(element_class, type): + raise TypeError(f"Element class is not a type, got {type(element_class)}") + if hasProxy(c_node): + # prevent re-entry race condition - we just called into Python + return getProxy(c_node) + result = element_class.__new__(element_class) + if hasProxy(c_node): + # prevent re-entry race condition - we just called into Python + result._c_node = NULL + return getProxy(c_node) + + _registerProxy(result, doc, c_node) + if element_class is not _Element: + result._init() + return result + + +@cython.internal +cdef class __ContentOnlyElement(_Element): + cdef int _raiseImmutable(self) except -1: + raise TypeError, "this element does not have children or attributes" + + def set(self, key, value): + "set(self, key, value)" + self._raiseImmutable() + + def append(self, value): + "append(self, value)" + self._raiseImmutable() + + def insert(self, index, value): + "insert(self, index, value)" + self._raiseImmutable() + + def __setitem__(self, index, value): + "__setitem__(self, index, value)" + self._raiseImmutable() + + @property + def attrib(self): + return IMMUTABLE_EMPTY_MAPPING + + property text: + def __get__(self): + _assertValidNode(self) + return funicodeOrEmpty(self._c_node.content) + + def __set__(self, value): + cdef tree.xmlDict* c_dict + _assertValidNode(self) + if value is None: + c_text = NULL + else: + value = _utf8(value) + c_text = _xcstr(value) + tree.xmlNodeSetContent(self._c_node, c_text) + + # ACCESSORS + def __getitem__(self, x): + "__getitem__(self, x)" + if isinstance(x, slice): + return [] + else: + raise IndexError, "list index out of range" + + def __len__(self): + "__len__(self)" + return 0 + + def get(self, key, default=None): + "get(self, key, default=None)" + return None + + def keys(self): + "keys(self)" + return [] + + def items(self): + "items(self)" + return [] + + def values(self): + "values(self)" + return [] + +cdef class _Comment(__ContentOnlyElement): + @property + def tag(self): + return Comment + + def __repr__(self): + return "" % self.text + +cdef class _ProcessingInstruction(__ContentOnlyElement): + @property + def tag(self): + return ProcessingInstruction + + property target: + # not in ElementTree + def __get__(self): + _assertValidNode(self) + return funicode(self._c_node.name) + + def __set__(self, value): + _assertValidNode(self) + value = _utf8(value) + c_text = _xcstr(value) + tree.xmlNodeSetName(self._c_node, c_text) + + def __repr__(self): + text = self.text + if text: + return "" % (self.target, text) + else: + return "" % self.target + + def get(self, key, default=None): + """get(self, key, default=None) + + Try to parse pseudo-attributes from the text content of the + processing instruction, search for one with the given key as + name and return its associated value. + + Note that this is only a convenience method for the most + common case that all text content is structured in + attribute-like name-value pairs with properly quoted values. + It is not guaranteed to work for all possible text content. + """ + return self.attrib.get(key, default) + + @property + def attrib(self): + """Returns a dict containing all pseudo-attributes that can be + parsed from the text content of this processing instruction. + Note that modifying the dict currently has no effect on the + XML node, although this is not guaranteed to stay this way. + """ + return { attr : (value1 or value2) + for attr, value1, value2 in _FIND_PI_ATTRIBUTES(' ' + self.text) } + +cdef object _FIND_PI_ATTRIBUTES = re.compile(r'\s+(\w+)\s*=\s*(?:\'([^\']*)\'|"([^"]*)")', re.U).findall + +cdef class _Entity(__ContentOnlyElement): + @property + def tag(self): + return Entity + + property name: + # not in ElementTree + def __get__(self): + _assertValidNode(self) + return funicode(self._c_node.name) + + def __set__(self, value): + _assertValidNode(self) + value_utf = _utf8(value) + if b'&' in value_utf or b';' in value_utf: + raise ValueError, f"Invalid entity name '{value}'" + tree.xmlNodeSetName(self._c_node, _xcstr(value_utf)) + + @property + def text(self): + # FIXME: should this be None or '&[VALUE];' or the resolved + # entity value ? + _assertValidNode(self) + return f'&{funicode(self._c_node.name)};' + + def __repr__(self): + return "&%s;" % self.name + + +cdef class QName: + """QName(text_or_uri_or_element, tag=None) + + QName wrapper for qualified XML names. + + Pass a tag name by itself or a namespace URI and a tag name to + create a qualified name. Alternatively, pass an Element to + extract its tag name. ``None`` as first argument is ignored in + order to allow for generic 2-argument usage. + + The ``text`` property holds the qualified name in + ``{namespace}tagname`` notation. The ``namespace`` and + ``localname`` properties hold the respective parts of the tag + name. + + You can pass QName objects wherever a tag name is expected. Also, + setting Element text from a QName will resolve the namespace prefix + on assignment and set a qualified text value. This is helpful in XML + languages like SOAP or XML-Schema that use prefixed tag names in + their text content. + """ + cdef readonly unicode text + cdef readonly unicode localname + cdef readonly unicode namespace + def __init__(self, text_or_uri_or_element, tag=None): + if text_or_uri_or_element is None: + # Allow None as no namespace. + text_or_uri_or_element, tag = tag, None + if not _isString(text_or_uri_or_element): + if isinstance(text_or_uri_or_element, _Element): + text_or_uri_or_element = (<_Element>text_or_uri_or_element).tag + if not _isString(text_or_uri_or_element): + raise ValueError, f"Invalid input tag of type {type(text_or_uri_or_element)!r}" + elif isinstance(text_or_uri_or_element, QName): + text_or_uri_or_element = (text_or_uri_or_element).text + elif text_or_uri_or_element is not None: + text_or_uri_or_element = unicode(text_or_uri_or_element) + else: + raise ValueError, f"Invalid input tag of type {type(text_or_uri_or_element)!r}" + + ns_utf, tag_utf = _getNsTag(text_or_uri_or_element) + if tag is not None: + # either ('ns', 'tag') or ('{ns}oldtag', 'newtag') + if ns_utf is None: + ns_utf = tag_utf # case 1: namespace ended up as tag name + tag_utf = _utf8(tag) + _tagValidOrRaise(tag_utf) + self.localname = (tag_utf).decode('utf8') + if ns_utf is None: + self.namespace = None + self.text = self.localname + else: + self.namespace = (ns_utf).decode('utf8') + self.text = "{%s}%s" % (self.namespace, self.localname) + def __str__(self): + return self.text + def __hash__(self): + return hash(self.text) + def __richcmp__(self, other, int op): + try: + if type(other) is QName: + other = (other).text + elif not isinstance(other, unicode): + other = unicode(other) + except (ValueError, UnicodeDecodeError): + return NotImplemented + return python.PyObject_RichCompare(self.text, other, op) + + +cdef public class _ElementTree [ type LxmlElementTreeType, + object LxmlElementTree ]: + cdef _Document _doc + cdef _Element _context_node + + # Note that _doc is only used to store the original document if we do not + # have a _context_node. All methods should prefer self._context_node._doc + # to honour tree restructuring. _doc can happily be None! + + @cython.final + cdef int _assertHasRoot(self) except -1: + """We have to take care here: the document may not have a root node! + This can happen if ElementTree() is called without any argument and + the caller 'forgets' to call parse() afterwards, so this is a bug in + the caller program. + """ + assert self._context_node is not None, \ + "ElementTree not initialized, missing root" + return 0 + + def parse(self, source, _BaseParser parser=None, *, base_url=None): + """parse(self, source, parser=None, base_url=None) + + Updates self with the content of source and returns its root. + """ + cdef _Document doc = None + try: + doc = _parseDocument(source, parser, base_url) + except _TargetParserResult as result_container: + # raises a TypeError if we don't get an _Element + self._context_node = result_container.result + else: + self._context_node = doc.getroot() + self._doc = None if self._context_node is not None else doc + return self._context_node + + def _setroot(self, _Element root not None): + """_setroot(self, root) + + Relocate the ElementTree to a new root node. + """ + _assertValidNode(root) + if root._c_node.type != tree.XML_ELEMENT_NODE: + raise TypeError, "Only elements can be the root of an ElementTree" + self._context_node = root + self._doc = None + + def getroot(self): + """getroot(self) + + Gets the root element for this tree. + """ + return self._context_node + + def __copy__(self): + return _elementTreeFactory(self._doc, self._context_node) + + def __deepcopy__(self, memo): + cdef _Element root + cdef _Document doc + cdef xmlDoc* c_doc + if self._context_node is not None: + root = self._context_node.__copy__() + assert root is not None + _assertValidNode(root) + _copyNonElementSiblings(self._context_node._c_node, root._c_node) + return _elementTreeFactory(None, root) + elif self._doc is not None: + _assertValidDoc(self._doc) + c_doc = tree.xmlCopyDoc(self._doc._c_doc, 1) + if c_doc is NULL: + raise MemoryError() + doc = _documentFactory(c_doc, self._doc._parser) + return _elementTreeFactory(doc, None) + else: + # so what ... + return self + + # not in ElementTree + @property + def docinfo(self) -> DocInfo: + """Information about the document provided by parser and DTD.""" + self._assertHasRoot() + return DocInfo(self._context_node._doc) + + # not in ElementTree, read-only + @property + def parser(self): + """The parser that was used to parse the document in this ElementTree. + """ + if self._context_node is not None and \ + self._context_node._doc is not None: + return self._context_node._doc._parser + if self._doc is not None: + return self._doc._parser + return None + + def write(self, file, *, encoding=None, method="xml", + bint pretty_print=False, xml_declaration=None, bint with_tail=True, + standalone=None, doctype=None, compression=0, + bint exclusive=False, inclusive_ns_prefixes=None, + bint with_comments=True, bint strip_text=False, + docstring=None): + """write(self, file, encoding=None, method="xml", + pretty_print=False, xml_declaration=None, with_tail=True, + standalone=None, doctype=None, compression=0, + exclusive=False, inclusive_ns_prefixes=None, + with_comments=True, strip_text=False) + + Write the tree to a filename, file or file-like object. + + Defaults to ASCII encoding and writing a declaration as needed. + + The keyword argument 'method' selects the output method: + 'xml', 'html', 'text', 'c14n' or 'c14n2'. Default is 'xml'. + + With ``method="c14n"`` (C14N version 1), the options ``exclusive``, + ``with_comments`` and ``inclusive_ns_prefixes`` request exclusive + C14N, include comments, and list the inclusive prefixes respectively. + + With ``method="c14n2"`` (C14N version 2), the ``with_comments`` and + ``strip_text`` options control the output of comments and text space + according to C14N 2.0. + + Passing a boolean value to the ``standalone`` option will + output an XML declaration with the corresponding + ``standalone`` flag. + + The ``doctype`` option allows passing in a plain string that will + be serialised before the XML tree. Note that passing in non + well-formed content here will make the XML output non well-formed. + Also, an existing doctype in the document tree will not be removed + when serialising an ElementTree instance. + + The ``compression`` option enables GZip compression level 1-9. + + The ``inclusive_ns_prefixes`` should be a list of namespace strings + (i.e. ['xs', 'xsi']) that will be promoted to the top-level element + during exclusive C14N serialisation. This parameter is ignored if + exclusive mode=False. + + If exclusive=True and no list is provided, a namespace will only be + rendered if it is used by the immediate parent or one of its attributes + and its prefix and values have not already been rendered by an ancestor + of the namespace node's parent element. + """ + cdef bint write_declaration + cdef int is_standalone + + self._assertHasRoot() + _assertValidNode(self._context_node) + if compression is None or compression < 0: + compression = 0 + + # C14N serialisation + if method in ('c14n', 'c14n2'): + if encoding is not None: + raise ValueError("Cannot specify encoding with C14N") + if xml_declaration: + raise ValueError("Cannot enable XML declaration in C14N") + + if method == 'c14n': + _tofilelikeC14N(file, self._context_node, exclusive, with_comments, + compression, inclusive_ns_prefixes) + else: # c14n2 + with _open_utf8_file(file, compression=compression) as f: + target = C14NWriterTarget( + f.write, with_comments=with_comments, strip_text=strip_text) + _tree_to_target(self, target) + return + + if not with_comments: + raise ValueError("Can only discard comments in C14N serialisation") + # suppress decl. in default case (purely for ElementTree compatibility) + if xml_declaration is not None: + write_declaration = xml_declaration + if encoding is None: + encoding = 'ASCII' + else: + encoding = encoding.upper() + elif encoding is None: + encoding = 'ASCII' + write_declaration = 0 + else: + encoding = encoding.upper() + write_declaration = encoding not in ( + 'US-ASCII', 'ASCII', 'UTF8', 'UTF-8') + if standalone is None: + is_standalone = -1 + elif standalone: + write_declaration = 1 + is_standalone = 1 + else: + write_declaration = 1 + is_standalone = 0 + + if docstring is not None and doctype is None: + import warnings + warnings.warn( + "The 'docstring' option is deprecated. Use 'doctype' instead.", + DeprecationWarning) + doctype = docstring + + _tofilelike(file, self._context_node, encoding, doctype, method, + write_declaration, 1, pretty_print, with_tail, + is_standalone, compression) + + def getpath(self, _Element element not None): + """getpath(self, element) + + Returns a structural, absolute XPath expression to find the element. + + For namespaced elements, the expression uses prefixes from the + document, which therefore need to be provided in order to make any + use of the expression in XPath. + + Also see the method getelementpath(self, element), which returns a + self-contained ElementPath expression. + """ + cdef _Document doc + cdef _Element root + cdef xmlDoc* c_doc + _assertValidNode(element) + if self._context_node is not None: + root = self._context_node + doc = root._doc + elif self._doc is not None: + doc = self._doc + root = doc.getroot() + else: + raise ValueError, "Element is not in this tree." + _assertValidDoc(doc) + _assertValidNode(root) + if element._doc is not doc: + raise ValueError, "Element is not in this tree." + + c_doc = _fakeRootDoc(doc._c_doc, root._c_node) + c_path = tree.xmlGetNodePath(element._c_node) + _destroyFakeDoc(doc._c_doc, c_doc) + if c_path is NULL: + raise MemoryError() + path = funicode(c_path) + tree.xmlFree(c_path) + return path + + def getelementpath(self, _Element element not None): + """getelementpath(self, element) + + Returns a structural, absolute ElementPath expression to find the + element. This path can be used in the .find() method to look up + the element, provided that the elements along the path and their + list of immediate children were not modified in between. + + ElementPath has the advantage over an XPath expression (as returned + by the .getpath() method) that it does not require additional prefix + declarations. It is always self-contained. + """ + cdef _Element root + cdef Py_ssize_t count + _assertValidNode(element) + if element._c_node.type != tree.XML_ELEMENT_NODE: + raise ValueError, "input is not an Element" + if self._context_node is not None: + root = self._context_node + elif self._doc is not None: + root = self._doc.getroot() + else: + raise ValueError, "Element is not in this tree" + _assertValidNode(root) + if element._doc is not root._doc: + raise ValueError, "Element is not in this tree" + + path = [] + c_element = element._c_node + while c_element is not root._c_node: + c_name = c_element.name + c_href = _getNs(c_element) + tag = _namespacedNameFromNsName(c_href, c_name) + if c_href is NULL: + c_href = b'' # no namespace (NULL is wildcard) + # use tag[N] if there are preceding siblings with the same tag + count = 0 + c_node = c_element.prev + while c_node is not NULL: + if c_node.type == tree.XML_ELEMENT_NODE: + if _tagMatches(c_node, c_href, c_name): + count += 1 + c_node = c_node.prev + if count: + tag = f'{tag}[{count+1}]' + else: + # use tag[1] if there are following siblings with the same tag + c_node = c_element.next + while c_node is not NULL: + if c_node.type == tree.XML_ELEMENT_NODE: + if _tagMatches(c_node, c_href, c_name): + tag += '[1]' + break + c_node = c_node.next + + path.append(tag) + c_element = c_element.parent + if c_element is NULL or c_element.type != tree.XML_ELEMENT_NODE: + raise ValueError, "Element is not in this tree." + if not path: + return '.' + path.reverse() + return '/'.join(path) + + def getiterator(self, tag=None, *tags): + """getiterator(self, *tags, tag=None) + + Returns a sequence or iterator of all elements in document order + (depth first pre-order), starting with the root element. + + Can be restricted to find only elements with specific tags, + see `_Element.iter`. + + :deprecated: Note that this method is deprecated as of + ElementTree 1.3 and lxml 2.0. It returns an iterator in + lxml, which diverges from the original ElementTree + behaviour. If you want an efficient iterator, use the + ``tree.iter()`` method instead. You should only use this + method in new code if you require backwards compatibility + with older versions of lxml or ElementTree. + """ + root = self.getroot() + if root is None: + return ITER_EMPTY + if tag is not None: + tags += (tag,) + return root.getiterator(*tags) + + def iter(self, tag=None, *tags): + """iter(self, tag=None, *tags) + + Creates an iterator for the root element. The iterator loops over + all elements in this tree, in document order. Note that siblings + of the root element (comments or processing instructions) are not + returned by the iterator. + + Can be restricted to find only elements with specific tags, + see `_Element.iter`. + """ + root = self.getroot() + if root is None: + return ITER_EMPTY + if tag is not None: + tags += (tag,) + return root.iter(*tags) + + def find(self, path, namespaces=None): + """find(self, path, namespaces=None) + + Finds the first toplevel element with given tag. Same as + ``tree.getroot().find(path)``. + + The optional ``namespaces`` argument accepts a + prefix-to-namespace mapping that allows the usage of XPath + prefixes in the path expression. + """ + self._assertHasRoot() + root = self.getroot() + if _isString(path): + if path[:1] == "/": + path = "." + path + from warnings import warn + warn( + "This search incorrectly ignores the root element, and will be " + "fixed in a future version. If you rely on the current " + f"behaviour, change it to {path!r}", + FutureWarning, stacklevel=1 + ) + return root.find(path, namespaces) + + def findtext(self, path, default=None, namespaces=None): + """findtext(self, path, default=None, namespaces=None) + + Finds the text for the first element matching the ElementPath + expression. Same as getroot().findtext(path) + + The optional ``namespaces`` argument accepts a + prefix-to-namespace mapping that allows the usage of XPath + prefixes in the path expression. + """ + self._assertHasRoot() + root = self.getroot() + if _isString(path): + if path[:1] == "/": + path = "." + path + from warnings import warn + warn( + "This search incorrectly ignores the root element, and will be " + "fixed in a future version. If you rely on the current " + f"behaviour, change it to {path!r}", + FutureWarning, stacklevel=1 + ) + return root.findtext(path, default, namespaces) + + def findall(self, path, namespaces=None): + """findall(self, path, namespaces=None) + + Finds all elements matching the ElementPath expression. Same as + getroot().findall(path). + + The optional ``namespaces`` argument accepts a + prefix-to-namespace mapping that allows the usage of XPath + prefixes in the path expression. + """ + self._assertHasRoot() + root = self.getroot() + if _isString(path): + if path[:1] == "/": + path = "." + path + from warnings import warn + warn( + "This search incorrectly ignores the root element, and will be " + "fixed in a future version. If you rely on the current " + f"behaviour, change it to {path!r}", + FutureWarning, stacklevel=1 + ) + return root.findall(path, namespaces) + + def iterfind(self, path, namespaces=None): + """iterfind(self, path, namespaces=None) + + Iterates over all elements matching the ElementPath expression. + Same as getroot().iterfind(path). + + The optional ``namespaces`` argument accepts a + prefix-to-namespace mapping that allows the usage of XPath + prefixes in the path expression. + """ + self._assertHasRoot() + root = self.getroot() + if _isString(path): + if path[:1] == "/": + path = "." + path + from warnings import warn + warn( + "This search incorrectly ignores the root element, and will be " + "fixed in a future version. If you rely on the current " + f"behaviour, change it to {path!r}", + FutureWarning, stacklevel=1 + ) + return root.iterfind(path, namespaces) + + def xpath(self, _path, *, namespaces=None, extensions=None, + smart_strings=True, **_variables): + """xpath(self, _path, namespaces=None, extensions=None, smart_strings=True, **_variables) + + XPath evaluate in context of document. + + ``namespaces`` is an optional dictionary with prefix to namespace URI + mappings, used by XPath. ``extensions`` defines additional extension + functions. + + Returns a list (nodeset), or bool, float or string. + + In case of a list result, return Element for element nodes, + string for text and attribute values. + + Note: if you are going to apply multiple XPath expressions + against the same document, it is more efficient to use + XPathEvaluator directly. + """ + self._assertHasRoot() + evaluator = XPathDocumentEvaluator(self, namespaces=namespaces, + extensions=extensions, + smart_strings=smart_strings) + return evaluator(_path, **_variables) + + def xslt(self, _xslt, extensions=None, access_control=None, **_kw): + """xslt(self, _xslt, extensions=None, access_control=None, **_kw) + + Transform this document using other document. + + xslt is a tree that should be XSLT + keyword parameters are XSLT transformation parameters. + + Returns the transformed tree. + + Note: if you are going to apply the same XSLT stylesheet against + multiple documents, it is more efficient to use the XSLT + class directly. + """ + self._assertHasRoot() + style = XSLT(_xslt, extensions=extensions, + access_control=access_control) + return style(self, **_kw) + + def relaxng(self, relaxng): + """relaxng(self, relaxng) + + Validate this document using other document. + + The relaxng argument is a tree that should contain a Relax NG schema. + + Returns True or False, depending on whether validation + succeeded. + + Note: if you are going to apply the same Relax NG schema against + multiple documents, it is more efficient to use the RelaxNG + class directly. + """ + self._assertHasRoot() + schema = RelaxNG(relaxng) + return schema.validate(self) + + def xmlschema(self, xmlschema): + """xmlschema(self, xmlschema) + + Validate this document using other document. + + The xmlschema argument is a tree that should contain an XML Schema. + + Returns True or False, depending on whether validation + succeeded. + + Note: If you are going to apply the same XML Schema against + multiple documents, it is more efficient to use the XMLSchema + class directly. + """ + self._assertHasRoot() + schema = XMLSchema(xmlschema) + return schema.validate(self) + + def xinclude(self): + """xinclude(self) + + Process the XInclude nodes in this document and include the + referenced XML fragments. + + There is support for loading files through the file system, HTTP and + FTP. + + Note that XInclude does not support custom resolvers in Python space + due to restrictions of libxml2 <= 2.6.29. + """ + self._assertHasRoot() + XInclude()(self._context_node) + + def write_c14n(self, file, *, bint exclusive=False, bint with_comments=True, + compression=0, inclusive_ns_prefixes=None): + """write_c14n(self, file, exclusive=False, with_comments=True, + compression=0, inclusive_ns_prefixes=None) + + C14N write of document. Always writes UTF-8. + + The ``compression`` option enables GZip compression level 1-9. + + The ``inclusive_ns_prefixes`` should be a list of namespace strings + (i.e. ['xs', 'xsi']) that will be promoted to the top-level element + during exclusive C14N serialisation. This parameter is ignored if + exclusive mode=False. + + If exclusive=True and no list is provided, a namespace will only be + rendered if it is used by the immediate parent or one of its attributes + and its prefix and values have not already been rendered by an ancestor + of the namespace node's parent element. + + NOTE: This method is deprecated as of lxml 4.4 and will be removed in a + future release. Use ``.write(f, method="c14n")`` instead. + """ + self._assertHasRoot() + _assertValidNode(self._context_node) + if compression is None or compression < 0: + compression = 0 + + _tofilelikeC14N(file, self._context_node, exclusive, with_comments, + compression, inclusive_ns_prefixes) + +cdef _ElementTree _elementTreeFactory(_Document doc, _Element context_node): + return _newElementTree(doc, context_node, _ElementTree) + +cdef _ElementTree _newElementTree(_Document doc, _Element context_node, + object baseclass): + cdef _ElementTree result + result = baseclass() + if context_node is None and doc is not None: + context_node = doc.getroot() + if context_node is None: + _assertValidDoc(doc) + result._doc = doc + else: + _assertValidNode(context_node) + result._context_node = context_node + return result + + +@cython.final +@cython.freelist(16) +cdef class _Attrib: + """A dict-like proxy for the ``Element.attrib`` property. + """ + cdef _Element _element + def __cinit__(self, _Element element not None): + _assertValidNode(element) + self._element = element + + # MANIPULATORS + def __setitem__(self, key, value): + _assertValidNode(self._element) + _setAttributeValue(self._element, key, value) + + def __delitem__(self, key): + _assertValidNode(self._element) + _delAttribute(self._element, key) + + def update(self, sequence_or_dict): + _assertValidNode(self._element) + if isinstance(sequence_or_dict, (dict, _Attrib)): + sequence_or_dict = sequence_or_dict.items() + for key, value in sequence_or_dict: + _setAttributeValue(self._element, key, value) + + def pop(self, key, *default): + if len(default) > 1: + raise TypeError, f"pop expected at most 2 arguments, got {len(default)+1}" + _assertValidNode(self._element) + result = _getAttributeValue(self._element, key, None) + if result is None: + if not default: + raise KeyError, key + result = default[0] + else: + _delAttribute(self._element, key) + return result + + def clear(self): + _assertValidNode(self._element) + c_attrs = self._element._c_node.properties + if c_attrs: + self._element._c_node.properties = NULL + tree.xmlFreePropList(c_attrs) + + # ACCESSORS + def __repr__(self): + _assertValidNode(self._element) + return repr(dict( _collectAttributes(self._element._c_node, 3) )) + + def __copy__(self): + _assertValidNode(self._element) + return dict(_collectAttributes(self._element._c_node, 3)) + + def __deepcopy__(self, memo): + _assertValidNode(self._element) + return dict(_collectAttributes(self._element._c_node, 3)) + + def __getitem__(self, key): + _assertValidNode(self._element) + result = _getAttributeValue(self._element, key, None) + if result is None: + raise KeyError, key + return result + + def __bool__(self): + _assertValidNode(self._element) + cdef xmlAttr* c_attr = self._element._c_node.properties + while c_attr is not NULL: + if c_attr.type == tree.XML_ATTRIBUTE_NODE: + return 1 + c_attr = c_attr.next + return 0 + + def __len__(self): + _assertValidNode(self._element) + cdef xmlAttr* c_attr = self._element._c_node.properties + cdef Py_ssize_t c = 0 + while c_attr is not NULL: + if c_attr.type == tree.XML_ATTRIBUTE_NODE: + c += 1 + c_attr = c_attr.next + return c + + def get(self, key, default=None): + _assertValidNode(self._element) + return _getAttributeValue(self._element, key, default) + + def keys(self): + _assertValidNode(self._element) + return _collectAttributes(self._element._c_node, 1) + + def __iter__(self): + _assertValidNode(self._element) + return iter(_collectAttributes(self._element._c_node, 1)) + + def iterkeys(self): + _assertValidNode(self._element) + return iter(_collectAttributes(self._element._c_node, 1)) + + def values(self): + _assertValidNode(self._element) + return _collectAttributes(self._element._c_node, 2) + + def itervalues(self): + _assertValidNode(self._element) + return iter(_collectAttributes(self._element._c_node, 2)) + + def items(self): + _assertValidNode(self._element) + return _collectAttributes(self._element._c_node, 3) + + def iteritems(self): + _assertValidNode(self._element) + return iter(_collectAttributes(self._element._c_node, 3)) + + def has_key(self, key): + _assertValidNode(self._element) + return key in self + + def __contains__(self, key): + _assertValidNode(self._element) + cdef xmlNode* c_node + ns, tag = _getNsTag(key) + c_node = self._element._c_node + c_href = NULL if ns is None else _xcstr(ns) + return 1 if tree.xmlHasNsProp(c_node, _xcstr(tag), c_href) else 0 + + def __richcmp__(self, other, int op): + try: + one = dict(self.items()) + if not isinstance(other, dict): + other = dict(other) + except (TypeError, ValueError): + return NotImplemented + return python.PyObject_RichCompare(one, other, op) + +MutableMapping.register(_Attrib) + + +@cython.final +@cython.internal +cdef class _AttribIterator: + """Attribute iterator - for internal use only! + """ + # XML attributes must not be removed while running! + cdef _Element _node + cdef xmlAttr* _c_attr + cdef int _keysvalues # 1 - keys, 2 - values, 3 - items (key, value) + def __iter__(self): + return self + + def __next__(self): + cdef xmlAttr* c_attr + if self._node is None: + raise StopIteration + c_attr = self._c_attr + while c_attr is not NULL and c_attr.type != tree.XML_ATTRIBUTE_NODE: + c_attr = c_attr.next + if c_attr is NULL: + self._node = None + raise StopIteration + + self._c_attr = c_attr.next + if self._keysvalues == 1: + return _namespacedName(c_attr) + elif self._keysvalues == 2: + return _attributeValue(self._node._c_node, c_attr) + else: + return (_namespacedName(c_attr), + _attributeValue(self._node._c_node, c_attr)) + +cdef object _attributeIteratorFactory(_Element element, int keysvalues): + cdef _AttribIterator attribs + if element._c_node.properties is NULL: + return ITER_EMPTY + attribs = _AttribIterator() + attribs._node = element + attribs._c_attr = element._c_node.properties + attribs._keysvalues = keysvalues + return attribs + + +cdef public class _ElementTagMatcher [ object LxmlElementTagMatcher, + type LxmlElementTagMatcherType ]: + """ + Dead but public. :) + """ + cdef object _pystrings + cdef int _node_type + cdef char* _href + cdef char* _name + cdef _initTagMatch(self, tag): + self._href = NULL + self._name = NULL + if tag is None: + self._node_type = 0 + elif tag is Comment: + self._node_type = tree.XML_COMMENT_NODE + elif tag is ProcessingInstruction: + self._node_type = tree.XML_PI_NODE + elif tag is Entity: + self._node_type = tree.XML_ENTITY_REF_NODE + elif tag is Element: + self._node_type = tree.XML_ELEMENT_NODE + else: + self._node_type = tree.XML_ELEMENT_NODE + self._pystrings = _getNsTag(tag) + if self._pystrings[0] is not None: + self._href = _cstr(self._pystrings[0]) + self._name = _cstr(self._pystrings[1]) + if self._name[0] == c'*' and self._name[1] == c'\0': + self._name = NULL + +cdef public class _ElementIterator(_ElementTagMatcher) [ + object LxmlElementIterator, type LxmlElementIteratorType ]: + """ + Dead but public. :) + """ + # we keep Python references here to control GC + cdef _Element _node + cdef _node_to_node_function _next_element + def __iter__(self): + return self + + cdef void _storeNext(self, _Element node): + cdef xmlNode* c_node + c_node = self._next_element(node._c_node) + while c_node is not NULL and \ + self._node_type != 0 and \ + (self._node_type != c_node.type or + not _tagMatches(c_node, self._href, self._name)): + c_node = self._next_element(c_node) + if c_node is NULL: + self._node = None + else: + # Python ref: + self._node = _elementFactory(node._doc, c_node) + + def __next__(self): + cdef xmlNode* c_node + cdef _Element current_node + if self._node is None: + raise StopIteration + # Python ref: + current_node = self._node + self._storeNext(current_node) + return current_node + +@cython.final +@cython.internal +cdef class _MultiTagMatcher: + """ + Match an xmlNode against a list of tags. + """ + cdef list _py_tags + cdef qname* _cached_tags + cdef size_t _tag_count + cdef size_t _cached_size + cdef _Document _cached_doc + cdef int _node_types + + def __cinit__(self, tags): + self._py_tags = [] + self.initTagMatch(tags) + + def __dealloc__(self): + self._clear() + + cdef bint rejectsAll(self) noexcept: + return not self._tag_count and not self._node_types + + cdef bint rejectsAllAttributes(self) noexcept: + return not self._tag_count + + cdef bint matchesType(self, int node_type) noexcept: + if node_type == tree.XML_ELEMENT_NODE and self._tag_count: + return True + return self._node_types & (1 << node_type) + + cdef void _clear(self) noexcept: + cdef size_t i, count + count = self._tag_count + self._tag_count = 0 + if self._cached_tags: + for i in range(count): + cpython.ref.Py_XDECREF(self._cached_tags[i].href) + python.lxml_free(self._cached_tags) + self._cached_tags = NULL + + cdef initTagMatch(self, tags): + self._cached_doc = None + del self._py_tags[:] + self._clear() + if tags is None or tags == (): + # no selection in tags argument => match anything + self._node_types = ( + 1 << tree.XML_COMMENT_NODE | + 1 << tree.XML_PI_NODE | + 1 << tree.XML_ENTITY_REF_NODE | + 1 << tree.XML_ELEMENT_NODE) + else: + self._node_types = 0 + self._storeTags(tags, set()) + + cdef _storeTags(self, tag, set seen): + if tag is Comment: + self._node_types |= 1 << tree.XML_COMMENT_NODE + elif tag is ProcessingInstruction: + self._node_types |= 1 << tree.XML_PI_NODE + elif tag is Entity: + self._node_types |= 1 << tree.XML_ENTITY_REF_NODE + elif tag is Element: + self._node_types |= 1 << tree.XML_ELEMENT_NODE + elif python._isString(tag): + if tag in seen: + return + seen.add(tag) + if tag in ('*', '{*}*'): + self._node_types |= 1 << tree.XML_ELEMENT_NODE + else: + href, name = _getNsTag(tag) + if name == b'*': + name = None + if href is None: + href = b'' # no namespace + elif href == b'*': + href = None # wildcard: any namespace, including none + self._py_tags.append((href, name)) + elif isinstance(tag, QName): + self._storeTags(tag.text, seen) + else: + # support a sequence of tags + for item in tag: + self._storeTags(item, seen) + + cdef inline int cacheTags(self, _Document doc, bint force_into_dict=False) except -1: + """ + Look up the tag names in the doc dict to enable string pointer comparisons. + """ + cdef size_t dict_size = tree.xmlDictSize(doc._c_doc.dict) + if doc is self._cached_doc and dict_size == self._cached_size: + # doc and dict didn't change => names already cached + return 0 + self._tag_count = 0 + if not self._py_tags: + self._cached_doc = doc + self._cached_size = dict_size + return 0 + if not self._cached_tags: + self._cached_tags = python.lxml_malloc(len(self._py_tags), sizeof(qname)) + if not self._cached_tags: + self._cached_doc = None + raise MemoryError() + self._tag_count = _mapTagsToQnameMatchArray( + doc._c_doc, self._py_tags, self._cached_tags, force_into_dict) + self._cached_doc = doc + self._cached_size = dict_size + return 0 + + cdef inline bint matches(self, xmlNode* c_node) noexcept: + cdef qname* c_qname + if self._node_types & (1 << c_node.type): + return True + elif c_node.type == tree.XML_ELEMENT_NODE: + for c_qname in self._cached_tags[:self._tag_count]: + if _tagMatchesExactly(c_node, c_qname): + return True + return False + + cdef inline bint matchesNsTag(self, const_xmlChar* c_href, + const_xmlChar* c_name) noexcept: + cdef qname* c_qname + if self._node_types & (1 << tree.XML_ELEMENT_NODE): + return True + for c_qname in self._cached_tags[:self._tag_count]: + if _nsTagMatchesExactly(c_href, c_name, c_qname): + return True + return False + + cdef inline bint matchesAttribute(self, xmlAttr* c_attr) noexcept: + """Attribute matches differ from Element matches in that they do + not care about node types. + """ + cdef qname* c_qname + for c_qname in self._cached_tags[:self._tag_count]: + if _tagMatchesExactly(c_attr, c_qname): + return True + return False + +cdef class _ElementMatchIterator: + cdef _Element _node + cdef _node_to_node_function _next_element + cdef _MultiTagMatcher _matcher + + @cython.final + cdef _initTagMatcher(self, tags): + self._matcher = _MultiTagMatcher.__new__(_MultiTagMatcher, tags) + + def __iter__(self): + return self + + @cython.final + cdef int _storeNext(self, _Element node) except -1: + self._matcher.cacheTags(node._doc) + c_node = self._next_element(node._c_node) + while c_node is not NULL and not self._matcher.matches(c_node): + c_node = self._next_element(c_node) + # store Python ref to next node to make sure it's kept alive + self._node = _elementFactory(node._doc, c_node) if c_node is not NULL else None + return 0 + + def __next__(self): + cdef _Element current_node = self._node + if current_node is None: + raise StopIteration + self._storeNext(current_node) + return current_node + +cdef class ElementChildIterator(_ElementMatchIterator): + """ElementChildIterator(self, node, tag=None, reversed=False) + Iterates over the children of an element. + """ + def __cinit__(self, _Element node not None, tag=None, *, bint reversed=False): + cdef xmlNode* c_node + _assertValidNode(node) + self._initTagMatcher(tag) + if reversed: + c_node = _findChildBackwards(node._c_node, 0) + self._next_element = _previousElement + else: + c_node = _findChildForwards(node._c_node, 0) + self._next_element = _nextElement + self._matcher.cacheTags(node._doc) + while c_node is not NULL and not self._matcher.matches(c_node): + c_node = self._next_element(c_node) + # store Python ref to next node to make sure it's kept alive + self._node = _elementFactory(node._doc, c_node) if c_node is not NULL else None + +cdef class SiblingsIterator(_ElementMatchIterator): + """SiblingsIterator(self, node, tag=None, preceding=False) + Iterates over the siblings of an element. + + You can pass the boolean keyword ``preceding`` to specify the direction. + """ + def __cinit__(self, _Element node not None, tag=None, *, bint preceding=False): + _assertValidNode(node) + self._initTagMatcher(tag) + if preceding: + self._next_element = _previousElement + else: + self._next_element = _nextElement + self._storeNext(node) + +cdef class AncestorsIterator(_ElementMatchIterator): + """AncestorsIterator(self, node, tag=None) + Iterates over the ancestors of an element (from parent to parent). + """ + def __cinit__(self, _Element node not None, tag=None): + _assertValidNode(node) + self._initTagMatcher(tag) + self._next_element = _parentElement + self._storeNext(node) + +cdef class ElementDepthFirstIterator: + """ElementDepthFirstIterator(self, node, tag=None, inclusive=True) + Iterates over an element and its sub-elements in document order (depth + first pre-order). + + Note that this also includes comments, entities and processing + instructions. To filter them out, check if the ``tag`` property + of the returned element is a string (i.e. not None and not a + factory function), or pass the ``Element`` factory for the ``tag`` + argument to receive only Elements. + + If the optional ``tag`` argument is not None, the iterator returns only + the elements that match the respective name and namespace. + + The optional boolean argument 'inclusive' defaults to True and can be set + to False to exclude the start element itself. + + Note that the behaviour of this iterator is completely undefined if the + tree it traverses is modified during iteration. + """ + # we keep Python references here to control GC + # keep the next Element after the one we return, and the (s)top node + cdef _Element _next_node + cdef _Element _top_node + cdef _MultiTagMatcher _matcher + def __cinit__(self, _Element node not None, tag=None, *, bint inclusive=True): + _assertValidNode(node) + self._top_node = node + self._next_node = node + self._matcher = _MultiTagMatcher.__new__(_MultiTagMatcher, tag) + self._matcher.cacheTags(node._doc) + if not inclusive or not self._matcher.matches(node._c_node): + # find start node (this cannot raise StopIteration, self._next_node != None) + next(self) + + def __iter__(self): + return self + + def __next__(self): + cdef xmlNode* c_node + cdef _Element current_node = self._next_node + if current_node is None: + raise StopIteration + c_node = current_node._c_node + self._matcher.cacheTags(current_node._doc) + if not self._matcher._tag_count: + # no tag name was found in the dict => not in document either + # try to match by node type + c_node = self._nextNodeAnyTag(c_node) + else: + c_node = self._nextNodeMatchTag(c_node) + if c_node is NULL: + self._next_node = None + else: + self._next_node = _elementFactory(current_node._doc, c_node) + return current_node + + @cython.final + cdef xmlNode* _nextNodeAnyTag(self, xmlNode* c_node) noexcept: + cdef int node_types = self._matcher._node_types + if not node_types: + return NULL + tree.BEGIN_FOR_EACH_ELEMENT_FROM(self._top_node._c_node, c_node, 0) + if node_types & (1 << c_node.type): + return c_node + tree.END_FOR_EACH_ELEMENT_FROM(c_node) + return NULL + + @cython.final + cdef xmlNode* _nextNodeMatchTag(self, xmlNode* c_node) noexcept: + tree.BEGIN_FOR_EACH_ELEMENT_FROM(self._top_node._c_node, c_node, 0) + if self._matcher.matches(c_node): + return c_node + tree.END_FOR_EACH_ELEMENT_FROM(c_node) + return NULL + + +cdef class ElementTextIterator: + """ElementTextIterator(self, element, tag=None, with_tail=True) + Iterates over the text content of a subtree. + + You can pass the ``tag`` keyword argument to restrict text content to a + specific tag name. + + You can set the ``with_tail`` keyword argument to ``False`` to skip over + tail text (e.g. if you know that it's only whitespace from pretty-printing). + """ + cdef object _events + cdef _Element _start_element + def __cinit__(self, _Element element not None, tag=None, *, bint with_tail=True): + _assertValidNode(element) + if with_tail: + events = ("start", "comment", "pi", "end") + else: + events = ("start",) + self._start_element = element + self._events = iterwalk(element, events=events, tag=tag) + + def __iter__(self): + return self + + def __next__(self): + cdef _Element element + result = None + while result is None: + event, element = next(self._events) # raises StopIteration + if event == "start": + result = element.text + elif element is not self._start_element: + result = element.tail + return result + + +cdef xmlNode* _createElement(xmlDoc* c_doc, object name_utf) except NULL: + cdef xmlNode* c_node + c_node = tree.xmlNewDocNode(c_doc, NULL, _xcstr(name_utf), NULL) + return c_node + +cdef xmlNode* _createComment(xmlDoc* c_doc, const_xmlChar* text) noexcept: + cdef xmlNode* c_node + c_node = tree.xmlNewDocComment(c_doc, text) + return c_node + +cdef xmlNode* _createPI(xmlDoc* c_doc, const_xmlChar* target, const_xmlChar* text) noexcept: + cdef xmlNode* c_node + c_node = tree.xmlNewDocPI(c_doc, target, text) + return c_node + +cdef xmlNode* _createEntity(xmlDoc* c_doc, const_xmlChar* name) noexcept: + cdef xmlNode* c_node + c_node = tree.xmlNewReference(c_doc, name) + return c_node + +# module-level API for ElementTree + +from abc import ABC + +class Element(ABC): + """Element(_tag, attrib=None, nsmap=None, **_extra) + + Element factory, as a class. + + An instance of this class is an object implementing the + Element interface. + + >>> element = Element("test") + >>> type(element) + + >>> isinstance(element, Element) + True + >>> issubclass(_Element, Element) + True + + Also look at the `_Element.makeelement()` and + `_BaseParser.makeelement()` methods, which provide a faster way to + create an Element within a specific document or parser context. + """ + def __new__(cls, _tag, attrib=None, nsmap=None, **_extra): + return _makeElement(_tag, NULL, None, None, None, None, + attrib, nsmap, _extra) + +# Register _Element as a virtual subclass of Element +Element.register(_Element) + + +def Comment(text=None): + """Comment(text=None) + + Comment element factory. This factory function creates a special element that will + be serialized as an XML comment. + """ + cdef _Document doc + cdef xmlNode* c_node + cdef xmlDoc* c_doc + + if text is None: + text = b'' + else: + text = _utf8(text) + if b'--' in text or text.endswith(b'-'): + raise ValueError("Comment may not contain '--' or end with '-'") + + c_doc = _newXMLDoc() + doc = _documentFactory(c_doc, None) + c_node = _createComment(c_doc, _xcstr(text)) + tree.xmlAddChild(c_doc, c_node) + return _elementFactory(doc, c_node) + + +def ProcessingInstruction(target, text=None): + """ProcessingInstruction(target, text=None) + + ProcessingInstruction element factory. This factory function creates a + special element that will be serialized as an XML processing instruction. + """ + cdef _Document doc + cdef xmlNode* c_node + cdef xmlDoc* c_doc + + target = _utf8(target) + _tagValidOrRaise(target) + if target.lower() == b'xml': + raise ValueError, f"Invalid PI name '{target}'" + + if text is None: + text = b'' + else: + text = _utf8(text) + if b'?>' in text: + raise ValueError, "PI text must not contain '?>'" + + c_doc = _newXMLDoc() + doc = _documentFactory(c_doc, None) + c_node = _createPI(c_doc, _xcstr(target), _xcstr(text)) + tree.xmlAddChild(c_doc, c_node) + return _elementFactory(doc, c_node) + +PI = ProcessingInstruction + + +cdef class CDATA: + """CDATA(data) + + CDATA factory. This factory creates an opaque data object that + can be used to set Element text. The usual way to use it is:: + + >>> el = Element('content') + >>> el.text = CDATA('a string') + + >>> print(el.text) + a string + >>> print(tostring(el, encoding="unicode")) + + """ + cdef bytes _utf8_data + def __cinit__(self, data): + self._utf8_data = _utf8(data) + + +def Entity(name): + """Entity(name) + + Entity factory. This factory function creates a special element + that will be serialized as an XML entity reference or character + reference. Note, however, that entities will not be automatically + declared in the document. A document that uses entity references + requires a DTD to define the entities. + """ + cdef _Document doc + cdef xmlNode* c_node + cdef xmlDoc* c_doc + name_utf = _utf8(name) + c_name = _xcstr(name_utf) + if c_name[0] == c'#': + if not _characterReferenceIsValid(c_name + 1): + raise ValueError, f"Invalid character reference: '{name}'" + elif not _xmlNameIsValid(c_name): + raise ValueError, f"Invalid entity reference: '{name}'" + c_doc = _newXMLDoc() + doc = _documentFactory(c_doc, None) + c_node = _createEntity(c_doc, c_name) + tree.xmlAddChild(c_doc, c_node) + return _elementFactory(doc, c_node) + + +def SubElement(_Element _parent not None, _tag, + attrib=None, nsmap=None, **_extra): + """SubElement(_parent, _tag, attrib=None, nsmap=None, **_extra) + + Subelement factory. This function creates an element instance, and + appends it to an existing element. + """ + return _makeSubElement(_parent, _tag, None, None, attrib, nsmap, _extra) + +from typing import Generic, TypeVar + +T = TypeVar("T") + +class ElementTree(ABC, Generic[T]): + def __new__(cls, _Element element=None, *, file=None, _BaseParser parser=None): + """ElementTree(element=None, file=None, parser=None) + + ElementTree wrapper class. + """ + cdef xmlNode* c_next + cdef xmlNode* c_node + cdef xmlNode* c_node_copy + cdef xmlDoc* c_doc + cdef _ElementTree etree + cdef _Document doc + + if element is not None: + doc = element._doc + elif file is not None: + try: + doc = _parseDocument(file, parser, None) + except _TargetParserResult as result_container: + return result_container.result + else: + c_doc = _newXMLDoc() + doc = _documentFactory(c_doc, parser) + + return _elementTreeFactory(doc, element) + +# Register _ElementTree as a virtual subclass of ElementTree +ElementTree.register(_ElementTree) + +# Remove "ABC" and typing helpers from module dict +del ABC, Generic, TypeVar, T + +def HTML(text, _BaseParser parser=None, *, base_url=None): + """HTML(text, parser=None, base_url=None) + + Parses an HTML document from a string constant. Returns the root + node (or the result returned by a parser target). This function + can be used to embed "HTML literals" in Python code. + + To override the parser with a different ``HTMLParser`` you can pass it to + the ``parser`` keyword argument. + + The ``base_url`` keyword argument allows to set the original base URL of + the document to support relative Paths when looking up external entities + (DTD, XInclude, ...). + """ + cdef _Document doc + if parser is None: + parser = __GLOBAL_PARSER_CONTEXT.getDefaultParser() + if not isinstance(parser, HTMLParser): + parser = __DEFAULT_HTML_PARSER + try: + doc = _parseMemoryDocument(text, base_url, parser) + return doc.getroot() + except _TargetParserResult as result_container: + return result_container.result + + +def XML(text, _BaseParser parser=None, *, base_url=None): + """XML(text, parser=None, base_url=None) + + Parses an XML document or fragment from a string constant. + Returns the root node (or the result returned by a parser target). + This function can be used to embed "XML literals" in Python code, + like in + + >>> root = XML("") + >>> print(root.tag) + root + + To override the parser with a different ``XMLParser`` you can pass it to + the ``parser`` keyword argument. + + The ``base_url`` keyword argument allows to set the original base URL of + the document to support relative Paths when looking up external entities + (DTD, XInclude, ...). + """ + cdef _Document doc + if parser is None: + parser = __GLOBAL_PARSER_CONTEXT.getDefaultParser() + if not isinstance(parser, XMLParser): + parser = __DEFAULT_XML_PARSER + try: + doc = _parseMemoryDocument(text, base_url, parser) + return doc.getroot() + except _TargetParserResult as result_container: + return result_container.result + + +def fromstring(text, _BaseParser parser=None, *, base_url=None): + """fromstring(text, parser=None, base_url=None) + + Parses an XML document or fragment from a string. Returns the + root node (or the result returned by a parser target). + + To override the default parser with a different parser you can pass it to + the ``parser`` keyword argument. + + The ``base_url`` keyword argument allows to set the original base URL of + the document to support relative Paths when looking up external entities + (DTD, XInclude, ...). + """ + cdef _Document doc + try: + doc = _parseMemoryDocument(text, base_url, parser) + return doc.getroot() + except _TargetParserResult as result_container: + return result_container.result + + +def fromstringlist(strings, _BaseParser parser=None): + """fromstringlist(strings, parser=None) + + Parses an XML document from a sequence of strings. Returns the + root node (or the result returned by a parser target). + + To override the default parser with a different parser you can pass it to + the ``parser`` keyword argument. + """ + cdef _Document doc + if isinstance(strings, (bytes, unicode)): + raise ValueError("passing a single string into fromstringlist() is not" + " efficient, use fromstring() instead") + if parser is None: + parser = __GLOBAL_PARSER_CONTEXT.getDefaultParser() + feed = parser.feed + for data in strings: + feed(data) + return parser.close() + + +def iselement(element): + """iselement(element) + + Checks if an object appears to be a valid element object. + """ + return isinstance(element, _Element) and (<_Element>element)._c_node is not NULL + + +def indent(tree, space=" ", *, Py_ssize_t level=0): + """indent(tree, space=" ", level=0) + + Indent an XML document by inserting newlines and indentation space + after elements. + + *tree* is the ElementTree or Element to modify. The (root) element + itself will not be changed, but the tail text of all elements in its + subtree will be adapted. + + *space* is the whitespace to insert for each indentation level, two + space characters by default. + + *level* is the initial indentation level. Setting this to a higher + value than 0 can be used for indenting subtrees that are more deeply + nested inside of a document. + """ + root = _rootNodeOrRaise(tree) + if level < 0: + raise ValueError(f"Initial indentation level must be >= 0, got {level}") + if _hasChild(root._c_node): + space = _utf8(space) + indent = b"\n" + level * space + _indent_children(root._c_node, 1, space, [indent, indent + space]) + + +cdef int _indent_children(xmlNode* c_node, Py_ssize_t level, bytes one_space, list indentations) except -1: + # Reuse indentation strings for speed. + if len(indentations) <= level: + indentations.append(indentations[-1] + one_space) + + # Start a new indentation level for the first child. + child_indentation = indentations[level] + if not _hasNonWhitespaceText(c_node): + _setNodeText(c_node, child_indentation) + + # Recursively indent all children. + cdef xmlNode* c_child = _findChildForwards(c_node, 0) + while c_child is not NULL: + if _hasChild(c_child): + _indent_children(c_child, level+1, one_space, indentations) + c_next_child = _nextElement(c_child) + if not _hasNonWhitespaceTail(c_child): + if c_next_child is NULL: + # Dedent after the last child. + child_indentation = indentations[level-1] + _setTailText(c_child, child_indentation) + c_child = c_next_child + return 0 + + +def dump(_Element elem not None, *, bint pretty_print=True, bint with_tail=True): + """dump(elem, pretty_print=True, with_tail=True) + + Writes an element tree or element structure to sys.stdout. This function + should be used for debugging only. + """ + xml = tostring(elem, pretty_print=pretty_print, with_tail=with_tail, encoding='unicode') + if not pretty_print: + xml += '\n' + sys.stdout.write(xml) + + +def tostring(element_or_tree, *, encoding=None, method="xml", + xml_declaration=None, bint pretty_print=False, bint with_tail=True, + standalone=None, doctype=None, + # method='c14n' + bint exclusive=False, inclusive_ns_prefixes=None, + # method='c14n2' + bint with_comments=True, bint strip_text=False, + ): + """tostring(element_or_tree, encoding=None, method="xml", + xml_declaration=None, pretty_print=False, with_tail=True, + standalone=None, doctype=None, + exclusive=False, inclusive_ns_prefixes=None, + with_comments=True, strip_text=False, + ) + + Serialize an element to an encoded string representation of its XML + tree. + + Defaults to ASCII encoding without XML declaration. This + behaviour can be configured with the keyword arguments 'encoding' + (string) and 'xml_declaration' (bool). Note that changing the + encoding to a non UTF-8 compatible encoding will enable a + declaration by default. + + You can also serialise to a Unicode string without declaration by + passing the name ``'unicode'`` as encoding (or the ``str`` function + in Py3 or ``unicode`` in Py2). This changes the return value from + a byte string to an unencoded unicode string. + + The keyword argument 'pretty_print' (bool) enables formatted XML. + + The keyword argument 'method' selects the output method: 'xml', + 'html', plain 'text' (text content without tags), 'c14n' or 'c14n2'. + Default is 'xml'. + + With ``method="c14n"`` (C14N version 1), the options ``exclusive``, + ``with_comments`` and ``inclusive_ns_prefixes`` request exclusive + C14N, include comments, and list the inclusive prefixes respectively. + + With ``method="c14n2"`` (C14N version 2), the ``with_comments`` and + ``strip_text`` options control the output of comments and text space + according to C14N 2.0. + + Passing a boolean value to the ``standalone`` option will output + an XML declaration with the corresponding ``standalone`` flag. + + The ``doctype`` option allows passing in a plain string that will + be serialised before the XML tree. Note that passing in non + well-formed content here will make the XML output non well-formed. + Also, an existing doctype in the document tree will not be removed + when serialising an ElementTree instance. + + You can prevent the tail text of the element from being serialised + by passing the boolean ``with_tail`` option. This has no impact + on the tail text of children, which will always be serialised. + """ + cdef bint write_declaration + cdef int is_standalone + # C14N serialisation + if method in ('c14n', 'c14n2'): + if encoding is not None: + raise ValueError("Cannot specify encoding with C14N") + if xml_declaration: + raise ValueError("Cannot enable XML declaration in C14N") + if method == 'c14n': + return _tostringC14N(element_or_tree, exclusive, with_comments, inclusive_ns_prefixes) + else: + out = BytesIO() + target = C14NWriterTarget( + utf8_writer(out).write, + with_comments=with_comments, strip_text=strip_text) + _tree_to_target(element_or_tree, target) + return out.getvalue() + if not with_comments: + raise ValueError("Can only discard comments in C14N serialisation") + if strip_text: + raise ValueError("Can only strip text in C14N 2.0 serialisation") + if encoding is unicode or (encoding is not None and encoding.lower() == 'unicode'): + if xml_declaration: + raise ValueError, \ + "Serialisation to unicode must not request an XML declaration" + write_declaration = 0 + encoding = unicode + elif xml_declaration is None: + # by default, write an XML declaration only for non-standard encodings + write_declaration = encoding is not None and encoding.upper() not in \ + ('ASCII', 'UTF-8', 'UTF8', 'US-ASCII') + else: + write_declaration = xml_declaration + if encoding is None: + encoding = 'ASCII' + if standalone is None: + is_standalone = -1 + elif standalone: + write_declaration = 1 + is_standalone = 1 + else: + write_declaration = 1 + is_standalone = 0 + + if isinstance(element_or_tree, _Element): + return _tostring(<_Element>element_or_tree, encoding, doctype, method, + write_declaration, 0, pretty_print, with_tail, + is_standalone) + elif isinstance(element_or_tree, _ElementTree): + return _tostring((<_ElementTree>element_or_tree)._context_node, + encoding, doctype, method, write_declaration, 1, + pretty_print, with_tail, is_standalone) + else: + raise TypeError, f"Type '{python._fqtypename(element_or_tree).decode('utf8')}' cannot be serialized." + + + +def tostringlist(element_or_tree, *args, **kwargs): + """tostringlist(element_or_tree, *args, **kwargs) + + Serialize an element to an encoded string representation of its XML + tree, stored in a list of partial strings. + + This is purely for ElementTree 1.3 compatibility. The result is a + single string wrapped in a list. + """ + return [tostring(element_or_tree, *args, **kwargs)] + + +def tounicode(element_or_tree, *, method="xml", bint pretty_print=False, + bint with_tail=True, doctype=None): + """tounicode(element_or_tree, method="xml", pretty_print=False, + with_tail=True, doctype=None) + + Serialize an element to the Python unicode representation of its XML + tree. + + :deprecated: use ``tostring(el, encoding='unicode')`` instead. + + Note that the result does not carry an XML encoding declaration and is + therefore not necessarily suited for serialization to byte streams without + further treatment. + + The boolean keyword argument 'pretty_print' enables formatted XML. + + The keyword argument 'method' selects the output method: 'xml', + 'html' or plain 'text'. + + You can prevent the tail text of the element from being serialised + by passing the boolean ``with_tail`` option. This has no impact + on the tail text of children, which will always be serialised. + """ + if isinstance(element_or_tree, _Element): + return _tostring(<_Element>element_or_tree, unicode, doctype, method, + 0, 0, pretty_print, with_tail, -1) + elif isinstance(element_or_tree, _ElementTree): + return _tostring((<_ElementTree>element_or_tree)._context_node, + unicode, doctype, method, 0, 1, pretty_print, + with_tail, -1) + else: + raise TypeError, f"Type '{type(element_or_tree)}' cannot be serialized." + + +def parse(source, _BaseParser parser=None, *, base_url=None): + """parse(source, parser=None, base_url=None) + + Return an ElementTree object loaded with source elements. If no parser + is provided as second argument, the default parser is used. + + The ``source`` can be any of the following: + + - a file name/path + - a file object + - a file-like object + - a URL using the HTTP or FTP protocol + + To parse from a string, use the ``fromstring()`` function instead. + + Note that it is generally faster to parse from a file path or URL + than from an open file object or file-like object. Transparent + decompression from gzip compressed sources is supported (unless + explicitly disabled in libxml2). + + The ``base_url`` keyword allows setting a URL for the document + when parsing from a file-like object. This is needed when looking + up external entities (DTD, XInclude, ...) with relative paths. + """ + cdef _Document doc + try: + doc = _parseDocument(source, parser, base_url) + return _elementTreeFactory(doc, None) + except _TargetParserResult as result_container: + return result_container.result + + +def adopt_external_document(capsule, _BaseParser parser=None): + """adopt_external_document(capsule, parser=None) + + Unpack a libxml2 document pointer from a PyCapsule and wrap it in an + lxml ElementTree object. + + This allows external libraries to build XML/HTML trees using libxml2 + and then pass them efficiently into lxml for further processing. + + If a ``parser`` is provided, it will be used for configuring the + lxml document. No parsing will be done. + + The capsule must have the name ``"libxml2:xmlDoc"`` and its pointer + value must reference a correct libxml2 document of type ``xmlDoc*``. + The creator of the capsule must take care to correctly clean up the + document using an appropriate capsule destructor. By default, the + libxml2 document will be copied to let lxml safely own the memory + of the internal tree that it uses. + + If the capsule context is non-NULL, it must point to a C string that + can be compared using ``strcmp()``. If the context string equals + ``"destructor:xmlFreeDoc"``, the libxml2 document will not be copied + but the capsule invalidated instead by clearing its destructor and + name. That way, lxml takes ownership of the libxml2 document in memory + without creating a copy first, and the capsule destructor will not be + called. The document will then eventually be cleaned up by lxml using + the libxml2 API function ``xmlFreeDoc()`` once it is no longer used. + + If no copy is made, later modifications of the tree outside of lxml + should not be attempted after transferring the ownership. + """ + cdef xmlDoc* c_doc + cdef bint is_owned = False + c_doc = python.lxml_unpack_xmldoc_capsule(capsule, &is_owned) + doc = _adoptForeignDoc(c_doc, parser, is_owned) + return _elementTreeFactory(doc, None) + + +################################################################################ +# Include submodules + +include "readonlytree.pxi" # Read-only implementation of Element proxies +include "classlookup.pxi" # Element class lookup mechanisms +include "nsclasses.pxi" # Namespace implementation and registry +include "docloader.pxi" # Support for custom document loaders +include "parser.pxi" # XML and HTML parsers +include "saxparser.pxi" # SAX-like Parser interface and tree builder +include "parsertarget.pxi" # ET Parser target +include "serializer.pxi" # XML output functions +include "iterparse.pxi" # incremental XML parsing +include "xmlid.pxi" # XMLID and IDDict +include "xinclude.pxi" # XInclude +include "cleanup.pxi" # Cleanup and recursive element removal functions + + +################################################################################ +# Include submodules for XPath and XSLT + +include "extensions.pxi" # XPath/XSLT extension functions +include "xpath.pxi" # XPath evaluation +include "xslt.pxi" # XSL transformations +include "xsltext.pxi" # XSL extension elements + + +################################################################################ +# Validation + +cdef class DocumentInvalid(LxmlError): + """Validation error. + + Raised by all document validators when their ``assertValid(tree)`` + method fails. + """ + + +cdef class _Validator: + "Base class for XML validators." + cdef _ErrorLog _error_log + def __cinit__(self): + self._error_log = _ErrorLog() + + def validate(self, etree): + """validate(self, etree) + + Validate the document using this schema. + + Returns true if document is valid, false if not. + """ + return self(etree) + + def assertValid(self, etree): + """assertValid(self, etree) + + Raises `DocumentInvalid` if the document does not comply with the schema. + """ + if not self(etree): + raise DocumentInvalid(self._error_log._buildExceptionMessage( + "Document does not comply with schema"), + self._error_log) + + def assert_(self, etree): + """assert_(self, etree) + + Raises `AssertionError` if the document does not comply with the schema. + """ + if not self(etree): + raise AssertionError, self._error_log._buildExceptionMessage( + "Document does not comply with schema") + + cpdef _append_log_message(self, int domain, int type, int level, int line, + message, filename): + self._error_log._receiveGeneric(domain, type, level, line, message, + filename) + + cpdef _clear_error_log(self): + self._error_log.clear() + + @property + def error_log(self): + """The log of validation errors and warnings.""" + assert self._error_log is not None, "XPath evaluator not initialised" + return self._error_log.copy() + +include "dtd.pxi" # DTD +include "relaxng.pxi" # RelaxNG +include "xmlschema.pxi" # XMLSchema +include "schematron.pxi" # Schematron (requires libxml2 2.6.21+) + +################################################################################ +# Public C API + +include "public-api.pxi" + +################################################################################ +# Other stuff + +include "debug.pxi" diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/etree_api.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/etree_api.h new file mode 100644 index 0000000000000000000000000000000000000000..bbbb86b5ead5938371cd1d6dd966889a18a57dec --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/etree_api.h @@ -0,0 +1,204 @@ +/* Generated by Cython 3.1.4 */ + +#ifndef __PYX_HAVE_API__lxml__etree +#define __PYX_HAVE_API__lxml__etree +#ifdef __MINGW64__ +#define MS_WIN64 +#endif +#include "Python.h" +#include "etree.h" + +static struct LxmlElement *(*__pyx_api_f_4lxml_5etree_deepcopyNodeToDocument)(struct LxmlDocument *, xmlNode *) = 0; +#define deepcopyNodeToDocument __pyx_api_f_4lxml_5etree_deepcopyNodeToDocument +static struct LxmlElementTree *(*__pyx_api_f_4lxml_5etree_elementTreeFactory)(struct LxmlElement *) = 0; +#define elementTreeFactory __pyx_api_f_4lxml_5etree_elementTreeFactory +static struct LxmlElementTree *(*__pyx_api_f_4lxml_5etree_newElementTree)(struct LxmlElement *, PyObject *) = 0; +#define newElementTree __pyx_api_f_4lxml_5etree_newElementTree +static struct LxmlElementTree *(*__pyx_api_f_4lxml_5etree_adoptExternalDocument)(xmlDoc *, PyObject *, int) = 0; +#define adoptExternalDocument __pyx_api_f_4lxml_5etree_adoptExternalDocument +static struct LxmlElement *(*__pyx_api_f_4lxml_5etree_elementFactory)(struct LxmlDocument *, xmlNode *) = 0; +#define elementFactory __pyx_api_f_4lxml_5etree_elementFactory +static struct LxmlElement *(*__pyx_api_f_4lxml_5etree_makeElement)(PyObject *, struct LxmlDocument *, PyObject *, PyObject *, PyObject *, PyObject *, PyObject *) = 0; +#define makeElement __pyx_api_f_4lxml_5etree_makeElement +static struct LxmlElement *(*__pyx_api_f_4lxml_5etree_makeSubElement)(struct LxmlElement *, PyObject *, PyObject *, PyObject *, PyObject *, PyObject *) = 0; +#define makeSubElement __pyx_api_f_4lxml_5etree_makeSubElement +static void (*__pyx_api_f_4lxml_5etree_setElementClassLookupFunction)(_element_class_lookup_function, PyObject *) = 0; +#define setElementClassLookupFunction __pyx_api_f_4lxml_5etree_setElementClassLookupFunction +static PyObject *(*__pyx_api_f_4lxml_5etree_lookupDefaultElementClass)(PyObject *, PyObject *, xmlNode *) = 0; +#define lookupDefaultElementClass __pyx_api_f_4lxml_5etree_lookupDefaultElementClass +static PyObject *(*__pyx_api_f_4lxml_5etree_lookupNamespaceElementClass)(PyObject *, PyObject *, xmlNode *) = 0; +#define lookupNamespaceElementClass __pyx_api_f_4lxml_5etree_lookupNamespaceElementClass +static PyObject *(*__pyx_api_f_4lxml_5etree_callLookupFallback)(struct LxmlFallbackElementClassLookup *, struct LxmlDocument *, xmlNode *) = 0; +#define callLookupFallback __pyx_api_f_4lxml_5etree_callLookupFallback +static int (*__pyx_api_f_4lxml_5etree_tagMatches)(xmlNode *, const xmlChar *, const xmlChar *) = 0; +#define tagMatches __pyx_api_f_4lxml_5etree_tagMatches +static struct LxmlDocument *(*__pyx_api_f_4lxml_5etree_documentOrRaise)(PyObject *) = 0; +#define documentOrRaise __pyx_api_f_4lxml_5etree_documentOrRaise +static struct LxmlElement *(*__pyx_api_f_4lxml_5etree_rootNodeOrRaise)(PyObject *) = 0; +#define rootNodeOrRaise __pyx_api_f_4lxml_5etree_rootNodeOrRaise +static int (*__pyx_api_f_4lxml_5etree_hasText)(xmlNode *) = 0; +#define hasText __pyx_api_f_4lxml_5etree_hasText +static int (*__pyx_api_f_4lxml_5etree_hasTail)(xmlNode *) = 0; +#define hasTail __pyx_api_f_4lxml_5etree_hasTail +static PyObject *(*__pyx_api_f_4lxml_5etree_textOf)(xmlNode *) = 0; +#define textOf __pyx_api_f_4lxml_5etree_textOf +static PyObject *(*__pyx_api_f_4lxml_5etree_tailOf)(xmlNode *) = 0; +#define tailOf __pyx_api_f_4lxml_5etree_tailOf +static int (*__pyx_api_f_4lxml_5etree_setNodeText)(xmlNode *, PyObject *) = 0; +#define setNodeText __pyx_api_f_4lxml_5etree_setNodeText +static int (*__pyx_api_f_4lxml_5etree_setTailText)(xmlNode *, PyObject *) = 0; +#define setTailText __pyx_api_f_4lxml_5etree_setTailText +static PyObject *(*__pyx_api_f_4lxml_5etree_attributeValue)(xmlNode *, xmlAttr *) = 0; +#define attributeValue __pyx_api_f_4lxml_5etree_attributeValue +static PyObject *(*__pyx_api_f_4lxml_5etree_attributeValueFromNsName)(xmlNode *, const xmlChar *, const xmlChar *) = 0; +#define attributeValueFromNsName __pyx_api_f_4lxml_5etree_attributeValueFromNsName +static PyObject *(*__pyx_api_f_4lxml_5etree_getAttributeValue)(struct LxmlElement *, PyObject *, PyObject *) = 0; +#define getAttributeValue __pyx_api_f_4lxml_5etree_getAttributeValue +static PyObject *(*__pyx_api_f_4lxml_5etree_iterattributes)(struct LxmlElement *, int) = 0; +#define iterattributes __pyx_api_f_4lxml_5etree_iterattributes +static PyObject *(*__pyx_api_f_4lxml_5etree_collectAttributes)(xmlNode *, int) = 0; +#define collectAttributes __pyx_api_f_4lxml_5etree_collectAttributes +static int (*__pyx_api_f_4lxml_5etree_setAttributeValue)(struct LxmlElement *, PyObject *, PyObject *) = 0; +#define setAttributeValue __pyx_api_f_4lxml_5etree_setAttributeValue +static int (*__pyx_api_f_4lxml_5etree_delAttribute)(struct LxmlElement *, PyObject *) = 0; +#define delAttribute __pyx_api_f_4lxml_5etree_delAttribute +static int (*__pyx_api_f_4lxml_5etree_delAttributeFromNsName)(xmlNode *, const xmlChar *, const xmlChar *) = 0; +#define delAttributeFromNsName __pyx_api_f_4lxml_5etree_delAttributeFromNsName +static int (*__pyx_api_f_4lxml_5etree_hasChild)(xmlNode *) = 0; +#define hasChild __pyx_api_f_4lxml_5etree_hasChild +static xmlNode *(*__pyx_api_f_4lxml_5etree_findChild)(xmlNode *, Py_ssize_t) = 0; +#define findChild __pyx_api_f_4lxml_5etree_findChild +static xmlNode *(*__pyx_api_f_4lxml_5etree_findChildForwards)(xmlNode *, Py_ssize_t) = 0; +#define findChildForwards __pyx_api_f_4lxml_5etree_findChildForwards +static xmlNode *(*__pyx_api_f_4lxml_5etree_findChildBackwards)(xmlNode *, Py_ssize_t) = 0; +#define findChildBackwards __pyx_api_f_4lxml_5etree_findChildBackwards +static xmlNode *(*__pyx_api_f_4lxml_5etree_nextElement)(xmlNode *) = 0; +#define nextElement __pyx_api_f_4lxml_5etree_nextElement +static xmlNode *(*__pyx_api_f_4lxml_5etree_previousElement)(xmlNode *) = 0; +#define previousElement __pyx_api_f_4lxml_5etree_previousElement +static void (*__pyx_api_f_4lxml_5etree_appendChild)(struct LxmlElement *, struct LxmlElement *) = 0; +#define appendChild __pyx_api_f_4lxml_5etree_appendChild +static int (*__pyx_api_f_4lxml_5etree_appendChildToElement)(struct LxmlElement *, struct LxmlElement *) = 0; +#define appendChildToElement __pyx_api_f_4lxml_5etree_appendChildToElement +static PyObject *(*__pyx_api_f_4lxml_5etree_pyunicode)(const xmlChar *) = 0; +#define pyunicode __pyx_api_f_4lxml_5etree_pyunicode +static PyObject *(*__pyx_api_f_4lxml_5etree_utf8)(PyObject *) = 0; +#define utf8 __pyx_api_f_4lxml_5etree_utf8 +static PyObject *(*__pyx_api_f_4lxml_5etree_getNsTag)(PyObject *) = 0; +#define getNsTag __pyx_api_f_4lxml_5etree_getNsTag +static PyObject *(*__pyx_api_f_4lxml_5etree_getNsTagWithEmptyNs)(PyObject *) = 0; +#define getNsTagWithEmptyNs __pyx_api_f_4lxml_5etree_getNsTagWithEmptyNs +static PyObject *(*__pyx_api_f_4lxml_5etree_namespacedName)(xmlNode *) = 0; +#define namespacedName __pyx_api_f_4lxml_5etree_namespacedName +static PyObject *(*__pyx_api_f_4lxml_5etree_namespacedNameFromNsName)(const xmlChar *, const xmlChar *) = 0; +#define namespacedNameFromNsName __pyx_api_f_4lxml_5etree_namespacedNameFromNsName +static void (*__pyx_api_f_4lxml_5etree_iteratorStoreNext)(struct LxmlElementIterator *, struct LxmlElement *) = 0; +#define iteratorStoreNext __pyx_api_f_4lxml_5etree_iteratorStoreNext +static void (*__pyx_api_f_4lxml_5etree_initTagMatch)(struct LxmlElementTagMatcher *, PyObject *) = 0; +#define initTagMatch __pyx_api_f_4lxml_5etree_initTagMatch +static xmlNs *(*__pyx_api_f_4lxml_5etree_findOrBuildNodeNsPrefix)(struct LxmlDocument *, xmlNode *, const xmlChar *, const xmlChar *) = 0; +#define findOrBuildNodeNsPrefix __pyx_api_f_4lxml_5etree_findOrBuildNodeNsPrefix +static int __Pyx_ImportFunction_3_1_4(PyObject *module, const char *funcname, void (**f)(void), const char *sig); + +#ifndef __PYX_HAVE_RT_ImportFunction_3_1_4 +#define __PYX_HAVE_RT_ImportFunction_3_1_4 +static int __Pyx_ImportFunction_3_1_4(PyObject *module, const char *funcname, void (**f)(void), const char *sig) { + PyObject *d = 0; + PyObject *cobj = 0; + union { + void (*fp)(void); + void *p; + } tmp; + d = PyObject_GetAttrString(module, "__pyx_capi__"); + if (!d) + goto bad; +#if (defined(Py_LIMITED_API) && Py_LIMITED_API >= 0x030d0000) || (!defined(Py_LIMITED_API) && PY_VERSION_HEX >= 0x030d0000) + PyDict_GetItemStringRef(d, funcname, &cobj); +#else + cobj = PyDict_GetItemString(d, funcname); + Py_XINCREF(cobj); +#endif + if (!cobj) { + PyErr_Format(PyExc_ImportError, + "%.200s does not export expected C function %.200s", + PyModule_GetName(module), funcname); + goto bad; + } + if (!PyCapsule_IsValid(cobj, sig)) { + PyErr_Format(PyExc_TypeError, + "C function %.200s.%.200s has wrong signature (expected %.500s, got %.500s)", + PyModule_GetName(module), funcname, sig, PyCapsule_GetName(cobj)); + goto bad; + } + tmp.p = PyCapsule_GetPointer(cobj, sig); + *f = tmp.fp; + if (!(*f)) + goto bad; + Py_DECREF(d); + Py_DECREF(cobj); + return 0; +bad: + Py_XDECREF(d); + Py_XDECREF(cobj); + return -1; +} +#endif + + +static int import_lxml__etree(void) { + PyObject *module = 0; + module = PyImport_ImportModule("lxml.etree"); + if (!module) goto bad; + if (__Pyx_ImportFunction_3_1_4(module, "deepcopyNodeToDocument", (void (**)(void))&__pyx_api_f_4lxml_5etree_deepcopyNodeToDocument, "struct LxmlElement *(struct LxmlDocument *, xmlNode *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_1_4(module, "elementTreeFactory", (void (**)(void))&__pyx_api_f_4lxml_5etree_elementTreeFactory, "struct LxmlElementTree *(struct LxmlElement *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_1_4(module, "newElementTree", (void (**)(void))&__pyx_api_f_4lxml_5etree_newElementTree, "struct LxmlElementTree *(struct LxmlElement *, PyObject *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_1_4(module, "adoptExternalDocument", (void (**)(void))&__pyx_api_f_4lxml_5etree_adoptExternalDocument, "struct LxmlElementTree *(xmlDoc *, PyObject *, int)") < 0) goto bad; + if (__Pyx_ImportFunction_3_1_4(module, "elementFactory", (void (**)(void))&__pyx_api_f_4lxml_5etree_elementFactory, "struct LxmlElement *(struct LxmlDocument *, xmlNode *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_1_4(module, "makeElement", (void (**)(void))&__pyx_api_f_4lxml_5etree_makeElement, "struct LxmlElement *(PyObject *, struct LxmlDocument *, PyObject *, PyObject *, PyObject *, PyObject *, PyObject *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_1_4(module, "makeSubElement", (void (**)(void))&__pyx_api_f_4lxml_5etree_makeSubElement, "struct LxmlElement *(struct LxmlElement *, PyObject *, PyObject *, PyObject *, PyObject *, PyObject *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_1_4(module, "setElementClassLookupFunction", (void (**)(void))&__pyx_api_f_4lxml_5etree_setElementClassLookupFunction, "void (_element_class_lookup_function, PyObject *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_1_4(module, "lookupDefaultElementClass", (void (**)(void))&__pyx_api_f_4lxml_5etree_lookupDefaultElementClass, "PyObject *(PyObject *, PyObject *, xmlNode *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_1_4(module, "lookupNamespaceElementClass", (void (**)(void))&__pyx_api_f_4lxml_5etree_lookupNamespaceElementClass, "PyObject *(PyObject *, PyObject *, xmlNode *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_1_4(module, "callLookupFallback", (void (**)(void))&__pyx_api_f_4lxml_5etree_callLookupFallback, "PyObject *(struct LxmlFallbackElementClassLookup *, struct LxmlDocument *, xmlNode *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_1_4(module, "tagMatches", (void (**)(void))&__pyx_api_f_4lxml_5etree_tagMatches, "int (xmlNode *, const xmlChar *, const xmlChar *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_1_4(module, "documentOrRaise", (void (**)(void))&__pyx_api_f_4lxml_5etree_documentOrRaise, "struct LxmlDocument *(PyObject *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_1_4(module, "rootNodeOrRaise", (void (**)(void))&__pyx_api_f_4lxml_5etree_rootNodeOrRaise, "struct LxmlElement *(PyObject *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_1_4(module, "hasText", (void (**)(void))&__pyx_api_f_4lxml_5etree_hasText, "int (xmlNode *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_1_4(module, "hasTail", (void (**)(void))&__pyx_api_f_4lxml_5etree_hasTail, "int (xmlNode *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_1_4(module, "textOf", (void (**)(void))&__pyx_api_f_4lxml_5etree_textOf, "PyObject *(xmlNode *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_1_4(module, "tailOf", (void (**)(void))&__pyx_api_f_4lxml_5etree_tailOf, "PyObject *(xmlNode *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_1_4(module, "setNodeText", (void (**)(void))&__pyx_api_f_4lxml_5etree_setNodeText, "int (xmlNode *, PyObject *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_1_4(module, "setTailText", (void (**)(void))&__pyx_api_f_4lxml_5etree_setTailText, "int (xmlNode *, PyObject *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_1_4(module, "attributeValue", (void (**)(void))&__pyx_api_f_4lxml_5etree_attributeValue, "PyObject *(xmlNode *, xmlAttr *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_1_4(module, "attributeValueFromNsName", (void (**)(void))&__pyx_api_f_4lxml_5etree_attributeValueFromNsName, "PyObject *(xmlNode *, const xmlChar *, const xmlChar *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_1_4(module, "getAttributeValue", (void (**)(void))&__pyx_api_f_4lxml_5etree_getAttributeValue, "PyObject *(struct LxmlElement *, PyObject *, PyObject *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_1_4(module, "iterattributes", (void (**)(void))&__pyx_api_f_4lxml_5etree_iterattributes, "PyObject *(struct LxmlElement *, int)") < 0) goto bad; + if (__Pyx_ImportFunction_3_1_4(module, "collectAttributes", (void (**)(void))&__pyx_api_f_4lxml_5etree_collectAttributes, "PyObject *(xmlNode *, int)") < 0) goto bad; + if (__Pyx_ImportFunction_3_1_4(module, "setAttributeValue", (void (**)(void))&__pyx_api_f_4lxml_5etree_setAttributeValue, "int (struct LxmlElement *, PyObject *, PyObject *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_1_4(module, "delAttribute", (void (**)(void))&__pyx_api_f_4lxml_5etree_delAttribute, "int (struct LxmlElement *, PyObject *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_1_4(module, "delAttributeFromNsName", (void (**)(void))&__pyx_api_f_4lxml_5etree_delAttributeFromNsName, "int (xmlNode *, const xmlChar *, const xmlChar *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_1_4(module, "hasChild", (void (**)(void))&__pyx_api_f_4lxml_5etree_hasChild, "int (xmlNode *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_1_4(module, "findChild", (void (**)(void))&__pyx_api_f_4lxml_5etree_findChild, "xmlNode *(xmlNode *, Py_ssize_t)") < 0) goto bad; + if (__Pyx_ImportFunction_3_1_4(module, "findChildForwards", (void (**)(void))&__pyx_api_f_4lxml_5etree_findChildForwards, "xmlNode *(xmlNode *, Py_ssize_t)") < 0) goto bad; + if (__Pyx_ImportFunction_3_1_4(module, "findChildBackwards", (void (**)(void))&__pyx_api_f_4lxml_5etree_findChildBackwards, "xmlNode *(xmlNode *, Py_ssize_t)") < 0) goto bad; + if (__Pyx_ImportFunction_3_1_4(module, "nextElement", (void (**)(void))&__pyx_api_f_4lxml_5etree_nextElement, "xmlNode *(xmlNode *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_1_4(module, "previousElement", (void (**)(void))&__pyx_api_f_4lxml_5etree_previousElement, "xmlNode *(xmlNode *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_1_4(module, "appendChild", (void (**)(void))&__pyx_api_f_4lxml_5etree_appendChild, "void (struct LxmlElement *, struct LxmlElement *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_1_4(module, "appendChildToElement", (void (**)(void))&__pyx_api_f_4lxml_5etree_appendChildToElement, "int (struct LxmlElement *, struct LxmlElement *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_1_4(module, "pyunicode", (void (**)(void))&__pyx_api_f_4lxml_5etree_pyunicode, "PyObject *(const xmlChar *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_1_4(module, "utf8", (void (**)(void))&__pyx_api_f_4lxml_5etree_utf8, "PyObject *(PyObject *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_1_4(module, "getNsTag", (void (**)(void))&__pyx_api_f_4lxml_5etree_getNsTag, "PyObject *(PyObject *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_1_4(module, "getNsTagWithEmptyNs", (void (**)(void))&__pyx_api_f_4lxml_5etree_getNsTagWithEmptyNs, "PyObject *(PyObject *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_1_4(module, "namespacedName", (void (**)(void))&__pyx_api_f_4lxml_5etree_namespacedName, "PyObject *(xmlNode *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_1_4(module, "namespacedNameFromNsName", (void (**)(void))&__pyx_api_f_4lxml_5etree_namespacedNameFromNsName, "PyObject *(const xmlChar *, const xmlChar *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_1_4(module, "iteratorStoreNext", (void (**)(void))&__pyx_api_f_4lxml_5etree_iteratorStoreNext, "void (struct LxmlElementIterator *, struct LxmlElement *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_1_4(module, "initTagMatch", (void (**)(void))&__pyx_api_f_4lxml_5etree_initTagMatch, "void (struct LxmlElementTagMatcher *, PyObject *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_1_4(module, "findOrBuildNodeNsPrefix", (void (**)(void))&__pyx_api_f_4lxml_5etree_findOrBuildNodeNsPrefix, "xmlNs *(struct LxmlDocument *, xmlNode *, const xmlChar *, const xmlChar *)") < 0) goto bad; + Py_DECREF(module); module = 0; + return 0; + bad: + Py_XDECREF(module); + return -1; +} + +#endif /* !__PYX_HAVE_API__lxml__etree */ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/extensions.pxi b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/extensions.pxi new file mode 100644 index 0000000000000000000000000000000000000000..ab687bec9c1d58c1220fae31bce1712d4751a9f2 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/extensions.pxi @@ -0,0 +1,830 @@ +# support for extension functions in XPath and XSLT + +cdef class XPathError(LxmlError): + """Base class of all XPath errors. + """ + +cdef class XPathEvalError(XPathError): + """Error during XPath evaluation. + """ + +cdef class XPathFunctionError(XPathEvalError): + """Internal error looking up an XPath extension function. + """ + +cdef class XPathResultError(XPathEvalError): + """Error handling an XPath result. + """ + + +# forward declarations + +ctypedef int (*_register_function)(void* ctxt, name_utf, ns_uri_utf) +cdef class _ExsltRegExp + +################################################################################ +# Base class for XSLT and XPath evaluation contexts: functions, namespaces, ... + +@cython.internal +cdef class _BaseContext: + cdef xpath.xmlXPathContext* _xpathCtxt + cdef _Document _doc + cdef dict _extensions + cdef list _namespaces + cdef list _global_namespaces + cdef dict _utf_refs + cdef dict _function_cache + cdef dict _eval_context_dict + cdef bint _build_smart_strings + # for exception handling and temporary reference keeping: + cdef _TempStore _temp_refs + cdef set _temp_documents + cdef _ExceptionContext _exc + cdef _ErrorLog _error_log + + def __init__(self, namespaces, extensions, error_log, enable_regexp, + build_smart_strings): + cdef _ExsltRegExp _regexp + cdef dict new_extensions + cdef list ns + self._utf_refs = {} + self._global_namespaces = [] + self._function_cache = {} + self._eval_context_dict = None + self._error_log = error_log + + if extensions is not None: + # convert extensions to UTF-8 + if isinstance(extensions, dict): + extensions = (extensions,) + # format: [ {(ns, name):function} ] -> {(ns_utf, name_utf):function} + new_extensions = {} + for extension in extensions: + for (ns_uri, name), function in extension.items(): + if name is None: + raise ValueError, "extensions must have non empty names" + ns_utf = self._to_utf(ns_uri) + name_utf = self._to_utf(name) + new_extensions[(ns_utf, name_utf)] = function + extensions = new_extensions or None + + if namespaces is not None: + if isinstance(namespaces, dict): + namespaces = namespaces.items() + if namespaces: + ns = [] + for prefix, ns_uri in namespaces: + if prefix is None or not prefix: + raise TypeError, \ + "empty namespace prefix is not supported in XPath" + if ns_uri is None or not ns_uri: + raise TypeError, \ + "setting default namespace is not supported in XPath" + prefix_utf = self._to_utf(prefix) + ns_uri_utf = self._to_utf(ns_uri) + ns.append( (prefix_utf, ns_uri_utf) ) + namespaces = ns + else: + namespaces = None + + self._doc = None + self._exc = _ExceptionContext() + self._extensions = extensions + self._namespaces = namespaces + self._temp_refs = _TempStore() + self._temp_documents = set() + self._build_smart_strings = build_smart_strings + + if enable_regexp: + _regexp = _ExsltRegExp() + _regexp._register_in_context(self) + + cdef _BaseContext _copy(self): + cdef _BaseContext context + if self._namespaces is not None: + namespaces = self._namespaces[:] + else: + namespaces = None + context = self.__class__(namespaces, None, self._error_log, False, + self._build_smart_strings) + if self._extensions is not None: + context._extensions = self._extensions.copy() + return context + + cdef bytes _to_utf(self, s): + "Convert to UTF-8 and keep a reference to the encoded string" + cdef python.PyObject* dict_result + if s is None: + return None + dict_result = python.PyDict_GetItem(self._utf_refs, s) + if dict_result is not NULL: + return dict_result + utf = _utf8(s) + self._utf_refs[s] = utf + if python.IS_PYPY: + # use C level refs, PyPy refs are not enough! + python.Py_INCREF(utf) + return utf + + cdef void _set_xpath_context(self, xpath.xmlXPathContext* xpathCtxt) noexcept: + self._xpathCtxt = xpathCtxt + xpathCtxt.userData = self + # Need a cast here because older libxml2 releases do not use 'const' in the functype. + xpathCtxt.error = _receiveXPathError + + @cython.final + cdef _register_context(self, _Document doc): + self._doc = doc + self._exc.clear() + + @cython.final + cdef _cleanup_context(self): + #xpath.xmlXPathRegisteredNsCleanup(self._xpathCtxt) + #self.unregisterGlobalNamespaces() + if python.IS_PYPY: + # clean up double refs in PyPy (see "_to_utf()" method) + for ref in self._utf_refs.itervalues(): + python.Py_DECREF(ref) + self._utf_refs.clear() + self._eval_context_dict = None + self._doc = None + + @cython.final + cdef _release_context(self): + if self._xpathCtxt is not NULL: + self._xpathCtxt.userData = NULL + self._xpathCtxt = NULL + + # namespaces (internal UTF-8 methods with leading '_') + + cdef addNamespace(self, prefix, ns_uri): + cdef list namespaces + if prefix is None: + raise TypeError, "empty prefix is not supported in XPath" + prefix_utf = self._to_utf(prefix) + ns_uri_utf = self._to_utf(ns_uri) + new_item = (prefix_utf, ns_uri_utf) + if self._namespaces is None: + self._namespaces = [new_item] + else: + namespaces = [] + for item in self._namespaces: + if item[0] == prefix_utf: + item = new_item + new_item = None + namespaces.append(item) + if new_item is not None: + namespaces.append(new_item) + self._namespaces = namespaces + if self._xpathCtxt is not NULL: + xpath.xmlXPathRegisterNs( + self._xpathCtxt, _xcstr(prefix_utf), _xcstr(ns_uri_utf)) + + cdef registerNamespace(self, prefix, ns_uri): + if prefix is None: + raise TypeError, "empty prefix is not supported in XPath" + prefix_utf = self._to_utf(prefix) + ns_uri_utf = self._to_utf(ns_uri) + self._global_namespaces.append(prefix_utf) + xpath.xmlXPathRegisterNs(self._xpathCtxt, + _xcstr(prefix_utf), _xcstr(ns_uri_utf)) + + cdef registerLocalNamespaces(self): + if self._namespaces is None: + return + for prefix_utf, ns_uri_utf in self._namespaces: + xpath.xmlXPathRegisterNs( + self._xpathCtxt, _xcstr(prefix_utf), _xcstr(ns_uri_utf)) + + cdef registerGlobalNamespaces(self): + cdef list ns_prefixes = _find_all_extension_prefixes() + if python.PyList_GET_SIZE(ns_prefixes) > 0: + for prefix_utf, ns_uri_utf in ns_prefixes: + self._global_namespaces.append(prefix_utf) + xpath.xmlXPathRegisterNs( + self._xpathCtxt, _xcstr(prefix_utf), _xcstr(ns_uri_utf)) + + cdef unregisterGlobalNamespaces(self): + if python.PyList_GET_SIZE(self._global_namespaces) > 0: + for prefix_utf in self._global_namespaces: + xpath.xmlXPathRegisterNs(self._xpathCtxt, + _xcstr(prefix_utf), NULL) + del self._global_namespaces[:] + + cdef void _unregisterNamespace(self, prefix_utf) noexcept: + xpath.xmlXPathRegisterNs(self._xpathCtxt, + _xcstr(prefix_utf), NULL) + + # extension functions + + cdef int _addLocalExtensionFunction(self, ns_utf, name_utf, function) except -1: + if self._extensions is None: + self._extensions = {} + self._extensions[(ns_utf, name_utf)] = function + return 0 + + cdef registerGlobalFunctions(self, void* ctxt, + _register_function reg_func): + cdef python.PyObject* dict_result + cdef dict d + for ns_utf, ns_functions in __FUNCTION_NAMESPACE_REGISTRIES.iteritems(): + dict_result = python.PyDict_GetItem( + self._function_cache, ns_utf) + if dict_result is not NULL: + d = dict_result + else: + d = {} + self._function_cache[ns_utf] = d + for name_utf, function in ns_functions.iteritems(): + d[name_utf] = function + reg_func(ctxt, name_utf, ns_utf) + + cdef registerLocalFunctions(self, void* ctxt, + _register_function reg_func): + cdef python.PyObject* dict_result + cdef dict d + if self._extensions is None: + return # done + last_ns = None + d = None + for (ns_utf, name_utf), function in self._extensions.iteritems(): + if ns_utf is not last_ns or d is None: + last_ns = ns_utf + dict_result = python.PyDict_GetItem( + self._function_cache, ns_utf) + if dict_result is not NULL: + d = dict_result + else: + d = {} + self._function_cache[ns_utf] = d + d[name_utf] = function + reg_func(ctxt, name_utf, ns_utf) + + cdef unregisterAllFunctions(self, void* ctxt, + _register_function unreg_func): + for ns_utf, functions in self._function_cache.iteritems(): + for name_utf in functions: + unreg_func(ctxt, name_utf, ns_utf) + + cdef unregisterGlobalFunctions(self, void* ctxt, + _register_function unreg_func): + for ns_utf, functions in self._function_cache.items(): + for name_utf in functions: + if self._extensions is None or \ + (ns_utf, name_utf) not in self._extensions: + unreg_func(ctxt, name_utf, ns_utf) + + @cython.final + cdef _find_cached_function(self, const_xmlChar* c_ns_uri, const_xmlChar* c_name): + """Lookup an extension function in the cache and return it. + + Parameters: c_ns_uri may be NULL, c_name must not be NULL + """ + cdef python.PyObject* c_dict + cdef python.PyObject* dict_result + c_dict = python.PyDict_GetItem( + self._function_cache, None if c_ns_uri is NULL else c_ns_uri) + if c_dict is not NULL: + dict_result = python.PyDict_GetItem( + c_dict, c_name) + if dict_result is not NULL: + return dict_result + return None + + # Python access to the XPath context for extension functions + + @property + def context_node(self): + cdef xmlNode* c_node + if self._xpathCtxt is NULL: + raise XPathError, \ + "XPath context is only usable during the evaluation" + c_node = self._xpathCtxt.node + if c_node is NULL: + raise XPathError, "no context node" + if c_node.doc != self._xpathCtxt.doc: + raise XPathError, \ + "document-external context nodes are not supported" + if self._doc is None: + raise XPathError, "document context is missing" + return _elementFactory(self._doc, c_node) + + @property + def eval_context(self): + if self._eval_context_dict is None: + self._eval_context_dict = {} + return self._eval_context_dict + + # Python reference keeping during XPath function evaluation + + @cython.final + cdef _release_temp_refs(self): + "Free temporarily referenced objects from this context." + self._temp_refs.clear() + self._temp_documents.clear() + + @cython.final + cdef _hold(self, obj): + """A way to temporarily hold references to nodes in the evaluator. + + This is needed because otherwise nodes created in XPath extension + functions would be reference counted too soon, during the XPath + evaluation. This is most important in the case of exceptions. + """ + cdef _Element element + if isinstance(obj, _Element): + self._temp_refs.add(obj) + self._temp_documents.add((<_Element>obj)._doc) + return + elif _isString(obj) or not python.PySequence_Check(obj): + return + for o in obj: + if isinstance(o, _Element): + #print "Holding element:", element._c_node + self._temp_refs.add(o) + #print "Holding document:", element._doc._c_doc + self._temp_documents.add((<_Element>o)._doc) + + @cython.final + cdef _Document _findDocumentForNode(self, xmlNode* c_node): + """If an XPath expression returns an element from a different + document than the current context document, we call this to + see if it was possibly created by an extension and is a known + document instance. + """ + cdef _Document doc + for doc in self._temp_documents: + if doc is not None and doc._c_doc is c_node.doc: + return doc + return None + + +# libxml2 keeps these error messages in a static array in its code +# and doesn't give us access to them ... + +cdef tuple LIBXML2_XPATH_ERROR_MESSAGES = ( + b"Ok", + b"Number encoding", + b"Unfinished literal", + b"Start of literal", + b"Expected $ for variable reference", + b"Undefined variable", + b"Invalid predicate", + b"Invalid expression", + b"Missing closing curly brace", + b"Unregistered function", + b"Invalid operand", + b"Invalid type", + b"Invalid number of arguments", + b"Invalid context size", + b"Invalid context position", + b"Memory allocation error", + b"Syntax error", + b"Resource error", + b"Sub resource error", + b"Undefined namespace prefix", + b"Encoding error", + b"Char out of XML range", + b"Invalid or incomplete context", + b"Stack usage error", + b"Forbidden variable\n", + b"?? Unknown error ??\n", +) + +cdef void _forwardXPathError(void* c_ctxt, const xmlerror.xmlError* c_error) noexcept with gil: + cdef xmlerror.xmlError error + cdef int xpath_code + if c_error.message is not NULL: + error.message = c_error.message + else: + xpath_code = c_error.code - xmlerror.XML_XPATH_EXPRESSION_OK + if 0 <= xpath_code < len(LIBXML2_XPATH_ERROR_MESSAGES): + error.message = _cstr(LIBXML2_XPATH_ERROR_MESSAGES[xpath_code]) + else: + error.message = b"unknown error" + error.domain = c_error.domain + error.code = c_error.code + error.level = c_error.level + error.line = c_error.line + error.int2 = c_error.int1 # column + error.file = c_error.file + error.node = NULL + + (<_BaseContext>c_ctxt)._error_log._receive(&error) + +cdef void _receiveXPathError(void* c_context, const xmlerror.xmlError* error) noexcept nogil: + if not __DEBUG: + return + if c_context is NULL: + _forwardError(NULL, error) + else: + _forwardXPathError(c_context, error) + + +def Extension(module, function_mapping=None, *, ns=None): + """Extension(module, function_mapping=None, ns=None) + + Build a dictionary of extension functions from the functions + defined in a module or the methods of an object. + + As second argument, you can pass an additional mapping of + attribute names to XPath function names, or a list of function + names that should be taken. + + The ``ns`` keyword argument accepts a namespace URI for the XPath + functions. + """ + cdef dict functions = {} + if isinstance(function_mapping, dict): + for function_name, xpath_name in function_mapping.items(): + functions[(ns, xpath_name)] = getattr(module, function_name) + else: + if function_mapping is None: + function_mapping = [ name for name in dir(module) + if not name.startswith('_') ] + for function_name in function_mapping: + functions[(ns, function_name)] = getattr(module, function_name) + return functions + +################################################################################ +# EXSLT regexp implementation + +@cython.final +@cython.internal +cdef class _ExsltRegExp: + cdef dict _compile_map + def __cinit__(self): + self._compile_map = {} + + cdef _make_string(self, value): + if _isString(value): + return value + elif isinstance(value, list): + # node set: take recursive text concatenation of first element + if python.PyList_GET_SIZE(value) == 0: + return '' + firstnode = value[0] + if _isString(firstnode): + return firstnode + elif isinstance(firstnode, _Element): + c_text = tree.xmlNodeGetContent((<_Element>firstnode)._c_node) + if c_text is NULL: + raise MemoryError() + try: + return funicode(c_text) + finally: + tree.xmlFree(c_text) + else: + return unicode(firstnode) + else: + return unicode(value) + + cdef _compile(self, rexp, ignore_case): + cdef python.PyObject* c_result + rexp = self._make_string(rexp) + key = (rexp, ignore_case) + c_result = python.PyDict_GetItem(self._compile_map, key) + if c_result is not NULL: + return c_result + py_flags = re.UNICODE + if ignore_case: + py_flags = py_flags | re.IGNORECASE + rexp_compiled = re.compile(rexp, py_flags) + self._compile_map[key] = rexp_compiled + return rexp_compiled + + def test(self, ctxt, s, rexp, flags=''): + flags = self._make_string(flags) + s = self._make_string(s) + rexpc = self._compile(rexp, 'i' in flags) + if rexpc.search(s) is None: + return False + else: + return True + + def match(self, ctxt, s, rexp, flags=''): + cdef list result_list + flags = self._make_string(flags) + s = self._make_string(s) + rexpc = self._compile(rexp, 'i' in flags) + if 'g' in flags: + results = rexpc.findall(s) + if not results: + return () + else: + result = rexpc.search(s) + if not result: + return () + results = [ result.group() ] + results.extend( result.groups('') ) + result_list = [] + root = Element('matches') + for s_match in results: + if python.PyTuple_CheckExact(s_match): + s_match = ''.join(s_match) + elem = SubElement(root, 'match') + elem.text = s_match + result_list.append(elem) + return result_list + + def replace(self, ctxt, s, rexp, flags, replacement): + replacement = self._make_string(replacement) + flags = self._make_string(flags) + s = self._make_string(s) + rexpc = self._compile(rexp, 'i' in flags) + count: object = 0 if 'g' in flags else 1 + return rexpc.sub(replacement, s, count) + + cdef _register_in_context(self, _BaseContext context): + ns = b"http://exslt.org/regular-expressions" + context._addLocalExtensionFunction(ns, b"test", self.test) + context._addLocalExtensionFunction(ns, b"match", self.match) + context._addLocalExtensionFunction(ns, b"replace", self.replace) + + +################################################################################ +# helper functions + +cdef xpath.xmlXPathObject* _wrapXPathObject(object obj, _Document doc, + _BaseContext context) except NULL: + cdef xpath.xmlNodeSet* resultSet + cdef _Element fake_node = None + cdef xmlNode* c_node + + if isinstance(obj, unicode): + obj = _utf8(obj) + if isinstance(obj, bytes): + # libxml2 copies the string value + return xpath.xmlXPathNewCString(_cstr(obj)) + if isinstance(obj, bool): + return xpath.xmlXPathNewBoolean(obj) + if python.PyNumber_Check(obj): + return xpath.xmlXPathNewFloat(obj) + if obj is None: + resultSet = xpath.xmlXPathNodeSetCreate(NULL) + elif isinstance(obj, _Element): + resultSet = xpath.xmlXPathNodeSetCreate((<_Element>obj)._c_node) + elif python.PySequence_Check(obj): + resultSet = xpath.xmlXPathNodeSetCreate(NULL) + try: + for value in obj: + if isinstance(value, _Element): + if context is not None: + context._hold(value) + xpath.xmlXPathNodeSetAdd(resultSet, (<_Element>value)._c_node) + else: + if context is None or doc is None: + raise XPathResultError, \ + f"Non-Element values not supported at this point - got {value!r}" + # support strings by appending text nodes to an Element + if isinstance(value, unicode): + value = _utf8(value) + if isinstance(value, bytes): + if fake_node is None: + fake_node = _makeElement("text-root", NULL, doc, None, + None, None, None, None, None) + context._hold(fake_node) + else: + # append a comment node to keep the text nodes separate + c_node = tree.xmlNewDocComment(doc._c_doc, "") + if c_node is NULL: + raise MemoryError() + tree.xmlAddChild(fake_node._c_node, c_node) + context._hold(value) + c_node = tree.xmlNewDocText(doc._c_doc, _xcstr(value)) + if c_node is NULL: + raise MemoryError() + tree.xmlAddChild(fake_node._c_node, c_node) + xpath.xmlXPathNodeSetAdd(resultSet, c_node) + else: + raise XPathResultError, \ + f"This is not a supported node-set result: {value!r}" + except: + xpath.xmlXPathFreeNodeSet(resultSet) + raise + else: + raise XPathResultError, f"Unknown return type: {python._fqtypename(obj).decode('utf8')}" + return xpath.xmlXPathWrapNodeSet(resultSet) + +cdef object _unwrapXPathObject(xpath.xmlXPathObject* xpathObj, + _Document doc, _BaseContext context): + if xpathObj.type == xpath.XPATH_UNDEFINED: + raise XPathResultError, "Undefined xpath result" + elif xpathObj.type == xpath.XPATH_NODESET: + return _createNodeSetResult(xpathObj, doc, context) + elif xpathObj.type == xpath.XPATH_BOOLEAN: + return xpathObj.boolval + elif xpathObj.type == xpath.XPATH_NUMBER: + return xpathObj.floatval + elif xpathObj.type == xpath.XPATH_STRING: + stringval = funicode(xpathObj.stringval) + if context._build_smart_strings: + stringval = _elementStringResultFactory( + stringval, None, None, False) + return stringval + elif xpathObj.type == xpath.XPATH_POINT: + raise NotImplementedError, "XPATH_POINT" + elif xpathObj.type == xpath.XPATH_RANGE: + raise NotImplementedError, "XPATH_RANGE" + elif xpathObj.type == xpath.XPATH_LOCATIONSET: + raise NotImplementedError, "XPATH_LOCATIONSET" + elif xpathObj.type == xpath.XPATH_USERS: + raise NotImplementedError, "XPATH_USERS" + elif xpathObj.type == xpath.XPATH_XSLT_TREE: + return _createNodeSetResult(xpathObj, doc, context) + else: + raise XPathResultError, f"Unknown xpath result {xpathObj.type}" + +cdef object _createNodeSetResult(xpath.xmlXPathObject* xpathObj, _Document doc, + _BaseContext context): + cdef xmlNode* c_node + cdef int i + cdef list result + result = [] + if xpathObj.nodesetval is NULL: + return result + for i in range(xpathObj.nodesetval.nodeNr): + c_node = xpathObj.nodesetval.nodeTab[i] + _unpackNodeSetEntry(result, c_node, doc, context, + xpathObj.type == xpath.XPATH_XSLT_TREE) + return result + +cdef _unpackNodeSetEntry(list results, xmlNode* c_node, _Document doc, + _BaseContext context, bint is_fragment): + cdef xmlNode* c_child + if _isElement(c_node): + if c_node.doc != doc._c_doc and c_node.doc._private is NULL: + # XXX: works, but maybe not always the right thing to do? + # XPath: only runs when extensions create or copy trees + # -> we store Python refs to these, so that is OK + # XSLT: can it leak when merging trees from multiple sources? + c_node = tree.xmlDocCopyNode(c_node, doc._c_doc, 1) + # FIXME: call _instantiateElementFromXPath() instead? + results.append( + _fakeDocElementFactory(doc, c_node)) + elif c_node.type == tree.XML_TEXT_NODE or \ + c_node.type == tree.XML_CDATA_SECTION_NODE or \ + c_node.type == tree.XML_ATTRIBUTE_NODE: + results.append( + _buildElementStringResult(doc, c_node, context)) + elif c_node.type == tree.XML_NAMESPACE_DECL: + results.append( (funicodeOrNone((c_node).prefix), + funicodeOrNone((c_node).href)) ) + elif c_node.type == tree.XML_DOCUMENT_NODE or \ + c_node.type == tree.XML_HTML_DOCUMENT_NODE: + # ignored for everything but result tree fragments + if is_fragment: + c_child = c_node.children + while c_child is not NULL: + _unpackNodeSetEntry(results, c_child, doc, context, 0) + c_child = c_child.next + elif c_node.type == tree.XML_XINCLUDE_START or \ + c_node.type == tree.XML_XINCLUDE_END: + pass + else: + raise NotImplementedError, \ + f"Not yet implemented result node type: {c_node.type}" + +cdef void _freeXPathObject(xpath.xmlXPathObject* xpathObj) noexcept: + """Free the XPath object, but *never* free the *content* of node sets. + Python dealloc will do that for us. + """ + if xpathObj.nodesetval is not NULL: + xpath.xmlXPathFreeNodeSet(xpathObj.nodesetval) + xpathObj.nodesetval = NULL + xpath.xmlXPathFreeObject(xpathObj) + +cdef _Element _instantiateElementFromXPath(xmlNode* c_node, _Document doc, + _BaseContext context): + # NOTE: this may copy the element - only call this when it can't leak + if c_node.doc != doc._c_doc and c_node.doc._private is NULL: + # not from the context document and not from a fake document + # either => may still be from a known document, e.g. one + # created by an extension function + node_doc = context._findDocumentForNode(c_node) + if node_doc is None: + # not from a known document at all! => can only make a + # safety copy here + c_node = tree.xmlDocCopyNode(c_node, doc._c_doc, 1) + else: + doc = node_doc + return _fakeDocElementFactory(doc, c_node) + +################################################################################ +# special str/unicode subclasses + +@cython.final +cdef class _ElementUnicodeResult(unicode): + cdef _Element _parent + cdef readonly object attrname + cdef readonly bint is_tail + + def getparent(self): + return self._parent + + @property + def is_text(self): + return self._parent is not None and not (self.is_tail or self.attrname is not None) + + @property + def is_attribute(self): + return self.attrname is not None + +cdef object _elementStringResultFactory(string_value, _Element parent, + attrname, bint is_tail): + result = _ElementUnicodeResult(string_value) + result._parent = parent + result.is_tail = is_tail + result.attrname = attrname + return result + +cdef object _buildElementStringResult(_Document doc, xmlNode* c_node, + _BaseContext context): + cdef _Element parent = None + cdef object attrname = None + cdef xmlNode* c_element + cdef bint is_tail + + if c_node.type == tree.XML_ATTRIBUTE_NODE: + attrname = _namespacedName(c_node) + is_tail = 0 + s = tree.xmlNodeGetContent(c_node) + try: + value = funicode(s) + finally: + tree.xmlFree(s) + c_element = NULL + else: + #assert c_node.type == tree.XML_TEXT_NODE or c_node.type == tree.XML_CDATA_SECTION_NODE, "invalid node type" + # may be tail text or normal text + value = funicode(c_node.content) + c_element = _previousElement(c_node) + is_tail = c_element is not NULL + + if not context._build_smart_strings: + return value + + if c_element is NULL: + # non-tail text or attribute text + c_element = c_node.parent + while c_element is not NULL and not _isElement(c_element): + c_element = c_element.parent + + if c_element is not NULL: + parent = _instantiateElementFromXPath(c_element, doc, context) + + return _elementStringResultFactory( + value, parent, attrname, is_tail) + +################################################################################ +# callbacks for XPath/XSLT extension functions + +cdef void _extension_function_call(_BaseContext context, function, + xpath.xmlXPathParserContext* ctxt, int nargs) noexcept: + cdef _Document doc + cdef xpath.xmlXPathObject* obj + cdef list args + cdef int i + doc = context._doc + try: + args = [] + for i in range(nargs): + obj = xpath.valuePop(ctxt) + o = _unwrapXPathObject(obj, doc, context) + _freeXPathObject(obj) + args.append(o) + args.reverse() + + res = function(context, *args) + # wrap result for XPath consumption + obj = _wrapXPathObject(res, doc, context) + # prevent Python from deallocating elements handed to libxml2 + context._hold(res) + xpath.valuePush(ctxt, obj) + except: + xpath.xmlXPathErr(ctxt, xpath.XPATH_EXPR_ERROR) + context._exc._store_raised() + finally: + return # swallow any further exceptions + +# lookup the function by name and call it + +cdef void _xpath_function_call(xpath.xmlXPathParserContext* ctxt, + int nargs) noexcept with gil: + cdef _BaseContext context + cdef xpath.xmlXPathContext* rctxt = ctxt.context + context = <_BaseContext> rctxt.userData + try: + function = context._find_cached_function(rctxt.functionURI, rctxt.function) + if function is not None: + _extension_function_call(context, function, ctxt, nargs) + else: + xpath.xmlXPathErr(ctxt, xpath.XPATH_UNKNOWN_FUNC_ERROR) + context._exc._store_exception(XPathFunctionError( + f"XPath function '{_namespacedNameFromNsName(rctxt.functionURI, rctxt.function)}' not found")) + except: + # may not be the right error, but we need to tell libxml2 *something* + xpath.xmlXPathErr(ctxt, xpath.XPATH_UNKNOWN_FUNC_ERROR) + context._exc._store_raised() + finally: + return # swallow any further exceptions diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/lxml.etree.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/lxml.etree.h new file mode 100644 index 0000000000000000000000000000000000000000..17b99a7be5c4159429d575c9e98f621f57c8310c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/lxml.etree.h @@ -0,0 +1,244 @@ +/* Generated by Cython 3.1.4 */ + +#ifndef __PYX_HAVE__lxml__etree +#define __PYX_HAVE__lxml__etree + +#include "Python.h" +struct LxmlDocument; +struct LxmlElement; +struct LxmlElementTree; +struct LxmlElementTagMatcher; +struct LxmlElementIterator; +struct LxmlElementBase; +struct LxmlElementClassLookup; +struct LxmlFallbackElementClassLookup; + +/* "lxml/etree.pyx":451 + * + * # type of a function that steps from node to node + * ctypedef public xmlNode* (*_node_to_node_function)(xmlNode*) # <<<<<<<<<<<<<< + * + * +*/ +typedef xmlNode *(*_node_to_node_function)(xmlNode *); + +/* "lxml/etree.pyx":465 + * # Public Python API + * + * @cython.final # <<<<<<<<<<<<<< + * @cython.freelist(8) + * cdef public class _Document [ type LxmlDocumentType, object LxmlDocument ]: +*/ +struct LxmlDocument { + PyObject_HEAD + struct __pyx_vtabstruct_4lxml_5etree__Document *__pyx_vtab; + int _ns_counter; + PyObject *_prefix_tail; + xmlDoc *_c_doc; + struct __pyx_obj_4lxml_5etree__BaseParser *_parser; +}; + +/* "lxml/etree.pyx":817 + * + * + * @cython.no_gc_clear # <<<<<<<<<<<<<< + * cdef public class _Element [ type LxmlElementType, object LxmlElement ]: + * """Element class. +*/ +struct LxmlElement { + PyObject_HEAD + struct LxmlDocument *_doc; + xmlNode *_c_node; + PyObject *_tag; +}; + +/* "lxml/etree.pyx":1991 + * + * + * cdef public class _ElementTree [ type LxmlElementTreeType, # <<<<<<<<<<<<<< + * object LxmlElementTree ]: + * cdef _Document _doc +*/ +struct LxmlElementTree { + PyObject_HEAD + struct __pyx_vtabstruct_4lxml_5etree__ElementTree *__pyx_vtab; + struct LxmlDocument *_doc; + struct LxmlElement *_context_node; +}; + +/* "lxml/etree.pyx":2765 + * + * + * cdef public class _ElementTagMatcher [ object LxmlElementTagMatcher, # <<<<<<<<<<<<<< + * type LxmlElementTagMatcherType ]: + * """ +*/ +struct LxmlElementTagMatcher { + PyObject_HEAD + struct __pyx_vtabstruct_4lxml_5etree__ElementTagMatcher *__pyx_vtab; + PyObject *_pystrings; + int _node_type; + char *_href; + char *_name; +}; + +/* "lxml/etree.pyx":2796 + * self._name = NULL + * + * cdef public class _ElementIterator(_ElementTagMatcher) [ # <<<<<<<<<<<<<< + * object LxmlElementIterator, type LxmlElementIteratorType ]: + * """ +*/ +struct LxmlElementIterator { + struct LxmlElementTagMatcher __pyx_base; + struct LxmlElement *_node; + _node_to_node_function _next_element; +}; + +/* "src/lxml/classlookup.pxi":6 + * # Custom Element classes + * + * cdef public class ElementBase(_Element) [ type LxmlElementBaseType, # <<<<<<<<<<<<<< + * object LxmlElementBase ]: + * """ElementBase(*children, attrib=None, nsmap=None, **_extra) +*/ +struct LxmlElementBase { + struct LxmlElement __pyx_base; +}; + +/* "src/lxml/classlookup.pxi":210 + * # Element class lookup + * + * ctypedef public object (*_element_class_lookup_function)(object, _Document, xmlNode*) # <<<<<<<<<<<<<< + * + * # class to store element class lookup functions +*/ +typedef PyObject *(*_element_class_lookup_function)(PyObject *, struct LxmlDocument *, xmlNode *); + +/* "src/lxml/classlookup.pxi":213 + * + * # class to store element class lookup functions + * cdef public class ElementClassLookup [ type LxmlElementClassLookupType, # <<<<<<<<<<<<<< + * object LxmlElementClassLookup ]: + * """ElementClassLookup(self) +*/ +struct LxmlElementClassLookup { + PyObject_HEAD + _element_class_lookup_function _lookup_function; +}; + +/* "src/lxml/classlookup.pxi":221 + * + * + * cdef public class FallbackElementClassLookup(ElementClassLookup) \ # <<<<<<<<<<<<<< + * [ type LxmlFallbackElementClassLookupType, + * object LxmlFallbackElementClassLookup ]: +*/ +struct LxmlFallbackElementClassLookup { + struct LxmlElementClassLookup __pyx_base; + struct __pyx_vtabstruct_4lxml_5etree_FallbackElementClassLookup *__pyx_vtab; + struct LxmlElementClassLookup *fallback; + _element_class_lookup_function _fallback_function; +}; + +#ifndef __PYX_HAVE_API__lxml__etree + +#ifdef CYTHON_EXTERN_C + #undef __PYX_EXTERN_C + #define __PYX_EXTERN_C CYTHON_EXTERN_C +#elif defined(__PYX_EXTERN_C) + #ifdef _MSC_VER + #pragma message ("Please do not define the '__PYX_EXTERN_C' macro externally. Use 'CYTHON_EXTERN_C' instead.") + #else + #warning Please do not define the '__PYX_EXTERN_C' macro externally. Use 'CYTHON_EXTERN_C' instead. + #endif +#else + #ifdef __cplusplus + #define __PYX_EXTERN_C extern "C" + #else + #define __PYX_EXTERN_C extern + #endif +#endif + +#ifndef DL_IMPORT + #define DL_IMPORT(_T) _T +#endif + +__PYX_EXTERN_C DL_IMPORT(PyTypeObject) LxmlDocumentType; +__PYX_EXTERN_C DL_IMPORT(PyTypeObject) LxmlElementType; +__PYX_EXTERN_C DL_IMPORT(PyTypeObject) LxmlElementTreeType; +__PYX_EXTERN_C DL_IMPORT(PyTypeObject) LxmlElementTagMatcherType; +__PYX_EXTERN_C DL_IMPORT(PyTypeObject) LxmlElementIteratorType; +__PYX_EXTERN_C DL_IMPORT(PyTypeObject) LxmlElementBaseType; +__PYX_EXTERN_C DL_IMPORT(PyTypeObject) LxmlElementClassLookupType; +__PYX_EXTERN_C DL_IMPORT(PyTypeObject) LxmlFallbackElementClassLookupType; + +__PYX_EXTERN_C struct LxmlElement *deepcopyNodeToDocument(struct LxmlDocument *, xmlNode *); +__PYX_EXTERN_C struct LxmlElementTree *elementTreeFactory(struct LxmlElement *); +__PYX_EXTERN_C struct LxmlElementTree *newElementTree(struct LxmlElement *, PyObject *); +__PYX_EXTERN_C struct LxmlElementTree *adoptExternalDocument(xmlDoc *, PyObject *, int); +__PYX_EXTERN_C struct LxmlElement *elementFactory(struct LxmlDocument *, xmlNode *); +__PYX_EXTERN_C struct LxmlElement *makeElement(PyObject *, struct LxmlDocument *, PyObject *, PyObject *, PyObject *, PyObject *, PyObject *); +__PYX_EXTERN_C struct LxmlElement *makeSubElement(struct LxmlElement *, PyObject *, PyObject *, PyObject *, PyObject *, PyObject *); +__PYX_EXTERN_C void setElementClassLookupFunction(_element_class_lookup_function, PyObject *); +__PYX_EXTERN_C PyObject *lookupDefaultElementClass(PyObject *, PyObject *, xmlNode *); +__PYX_EXTERN_C PyObject *lookupNamespaceElementClass(PyObject *, PyObject *, xmlNode *); +__PYX_EXTERN_C PyObject *callLookupFallback(struct LxmlFallbackElementClassLookup *, struct LxmlDocument *, xmlNode *); +__PYX_EXTERN_C int tagMatches(xmlNode *, const xmlChar *, const xmlChar *); +__PYX_EXTERN_C struct LxmlDocument *documentOrRaise(PyObject *); +__PYX_EXTERN_C struct LxmlElement *rootNodeOrRaise(PyObject *); +__PYX_EXTERN_C int hasText(xmlNode *); +__PYX_EXTERN_C int hasTail(xmlNode *); +__PYX_EXTERN_C PyObject *textOf(xmlNode *); +__PYX_EXTERN_C PyObject *tailOf(xmlNode *); +__PYX_EXTERN_C int setNodeText(xmlNode *, PyObject *); +__PYX_EXTERN_C int setTailText(xmlNode *, PyObject *); +__PYX_EXTERN_C PyObject *attributeValue(xmlNode *, xmlAttr *); +__PYX_EXTERN_C PyObject *attributeValueFromNsName(xmlNode *, const xmlChar *, const xmlChar *); +__PYX_EXTERN_C PyObject *getAttributeValue(struct LxmlElement *, PyObject *, PyObject *); +__PYX_EXTERN_C PyObject *iterattributes(struct LxmlElement *, int); +__PYX_EXTERN_C PyObject *collectAttributes(xmlNode *, int); +__PYX_EXTERN_C int setAttributeValue(struct LxmlElement *, PyObject *, PyObject *); +__PYX_EXTERN_C int delAttribute(struct LxmlElement *, PyObject *); +__PYX_EXTERN_C int delAttributeFromNsName(xmlNode *, const xmlChar *, const xmlChar *); +__PYX_EXTERN_C int hasChild(xmlNode *); +__PYX_EXTERN_C xmlNode *findChild(xmlNode *, Py_ssize_t); +__PYX_EXTERN_C xmlNode *findChildForwards(xmlNode *, Py_ssize_t); +__PYX_EXTERN_C xmlNode *findChildBackwards(xmlNode *, Py_ssize_t); +__PYX_EXTERN_C xmlNode *nextElement(xmlNode *); +__PYX_EXTERN_C xmlNode *previousElement(xmlNode *); +__PYX_EXTERN_C void appendChild(struct LxmlElement *, struct LxmlElement *); +__PYX_EXTERN_C int appendChildToElement(struct LxmlElement *, struct LxmlElement *); +__PYX_EXTERN_C PyObject *pyunicode(const xmlChar *); +__PYX_EXTERN_C PyObject *utf8(PyObject *); +__PYX_EXTERN_C PyObject *getNsTag(PyObject *); +__PYX_EXTERN_C PyObject *getNsTagWithEmptyNs(PyObject *); +__PYX_EXTERN_C PyObject *namespacedName(xmlNode *); +__PYX_EXTERN_C PyObject *namespacedNameFromNsName(const xmlChar *, const xmlChar *); +__PYX_EXTERN_C void iteratorStoreNext(struct LxmlElementIterator *, struct LxmlElement *); +__PYX_EXTERN_C void initTagMatch(struct LxmlElementTagMatcher *, PyObject *); +__PYX_EXTERN_C xmlNs *findOrBuildNodeNsPrefix(struct LxmlDocument *, xmlNode *, const xmlChar *, const xmlChar *); + +#endif /* !__PYX_HAVE_API__lxml__etree */ + +/* WARNING: the interface of the module init function changed in CPython 3.5. */ +/* It now returns a PyModuleDef instance instead of a PyModule instance. */ + +/* WARNING: Use PyImport_AppendInittab("etree", PyInit_etree) instead of calling PyInit_etree directly from Python 3.5 */ +PyMODINIT_FUNC PyInit_etree(void); + +#if PY_VERSION_HEX >= 0x03050000 && (defined(__GNUC__) || defined(__clang__) || defined(_MSC_VER) || (defined(__cplusplus) && __cplusplus >= 201402L)) +#if defined(__cplusplus) && __cplusplus >= 201402L +[[deprecated("Use PyImport_AppendInittab(\"etree\", PyInit_etree) instead of calling PyInit_etree directly.")]] inline +#elif defined(__GNUC__) || defined(__clang__) +__attribute__ ((__deprecated__("Use PyImport_AppendInittab(\"etree\", PyInit_etree) instead of calling PyInit_etree directly."), __unused__)) __inline__ +#elif defined(_MSC_VER) +__declspec(deprecated("Use PyImport_AppendInittab(\"etree\", PyInit_etree) instead of calling PyInit_etree directly.")) __inline +#endif +static PyObject* __PYX_WARN_IF_PyInit_etree_INIT_CALLED(PyObject* res) { + return res; +} +#define PyInit_etree() __PYX_WARN_IF_PyInit_etree_INIT_CALLED(PyInit_etree()) +#endif + +#endif /* !__PYX_HAVE__lxml__etree */ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/objectpath.pxi b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/objectpath.pxi new file mode 100644 index 0000000000000000000000000000000000000000..e562a365015830bfd3d24650d1109fe891c31039 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/objectpath.pxi @@ -0,0 +1,332 @@ +################################################################################ +# ObjectPath + +ctypedef struct _ObjectPath: + const_xmlChar* href + const_xmlChar* name + Py_ssize_t index + + +cdef object _NO_DEFAULT = object() + + +cdef class ObjectPath: + """ObjectPath(path) + Immutable object that represents a compiled object path. + + Example for a path: 'root.child[1].{other}child[25]' + """ + cdef readonly object find + cdef list _path + cdef object _path_str + cdef _ObjectPath* _c_path + cdef Py_ssize_t _path_len + def __init__(self, path): + if python._isString(path): + self._path = _parse_object_path_string(path) + self._path_str = path + else: + self._path = _parse_object_path_list(path) + self._path_str = '.'.join(path) + self._path_len = len(self._path) + self._c_path = _build_object_path_segments(self._path) + self.find = self.__call__ + + def __dealloc__(self): + if self._c_path is not NULL: + python.lxml_free(self._c_path) + + def __str__(self): + return self._path_str + + def __call__(self, _Element root not None, *_default): + """Follow the attribute path in the object structure and return the + target attribute value. + + If it it not found, either returns a default value (if one was passed + as second argument) or raises AttributeError. + """ + if _default: + if len(_default) > 1: + raise TypeError, "invalid number of arguments: needs one or two" + default = _default[0] + else: + default = _NO_DEFAULT + return _find_object_path(root, self._c_path, self._path_len, default) + + def hasattr(self, _Element root not None): + "hasattr(self, root)" + try: + _find_object_path(root, self._c_path, self._path_len, _NO_DEFAULT) + except AttributeError: + return False + return True + + def setattr(self, _Element root not None, value): + """setattr(self, root, value) + + Set the value of the target element in a subtree. + + If any of the children on the path does not exist, it is created. + """ + _create_object_path(root, self._c_path, self._path_len, 1, value) + + def addattr(self, _Element root not None, value): + """addattr(self, root, value) + + Append a value to the target element in a subtree. + + If any of the children on the path does not exist, it is created. + """ + _create_object_path(root, self._c_path, self._path_len, 0, value) + + +cdef object __MATCH_PATH_SEGMENT = re.compile( + r"(\.?)\s*(?:\{([^}]*)\})?\s*([^.{}\[\]\s]+)\s*(?:\[\s*([-0-9]+)\s*\])?", + re.U).match + +cdef tuple _RELATIVE_PATH_SEGMENT = (None, None, 0) + + +cdef list _parse_object_path_string(_path): + """Parse object path string into a (ns, name, index) list. + """ + cdef bint has_dot + cdef unicode path + new_path = [] + if isinstance(_path, bytes): + path = (_path).decode('ascii') + elif type(_path) is not unicode: + path = unicode(_path) + else: + path = _path + path = path.strip() + if path == '.': + return [_RELATIVE_PATH_SEGMENT] + path_pos = 0 + while path: + match = __MATCH_PATH_SEGMENT(path, path_pos) + if match is None: + break + + dot, ns, name, index = match.groups() + index = int(index) if index else 0 + has_dot = dot == '.' + if not new_path: + if has_dot: + # path '.child' => ignore root + new_path.append(_RELATIVE_PATH_SEGMENT) + elif index: + raise ValueError, "index not allowed on root node" + elif not has_dot: + raise ValueError, "invalid path" + if ns is not None: + ns = python.PyUnicode_AsUTF8String(ns) + name = python.PyUnicode_AsUTF8String(name) + new_path.append( (ns, name, index) ) + + path_pos = match.end() + if not new_path or len(path) > path_pos: + raise ValueError, "invalid path" + return new_path + + +cdef list _parse_object_path_list(path): + """Parse object path sequence into a (ns, name, index) list. + """ + new_path = [] + for item in path: + item = item.strip() + if not new_path and item == '': + # path '.child' => ignore root + ns = name = None + index = 0 + else: + ns, name = cetree.getNsTag(item) + c_name = _xcstr(name) + index_pos = tree.xmlStrchr(c_name, c'[') + if index_pos is NULL: + index = 0 + else: + index_end = tree.xmlStrchr(index_pos + 1, c']') + if index_end is NULL: + raise ValueError, "index must be enclosed in []" + index = int(index_pos[1:index_end - index_pos]) + if not new_path and index != 0: + raise ValueError, "index not allowed on root node" + name = c_name[:index_pos - c_name] + new_path.append( (ns, name, index) ) + if not new_path: + raise ValueError, "invalid path" + return new_path + + +cdef _ObjectPath* _build_object_path_segments(list path_list) except NULL: + cdef _ObjectPath* c_path + cdef _ObjectPath* c_path_segments + c_path_segments = <_ObjectPath*>python.lxml_malloc(len(path_list), sizeof(_ObjectPath)) + if c_path_segments is NULL: + raise MemoryError() + c_path = c_path_segments + for href, name, index in path_list: + c_path[0].href = _xcstr(href) if href is not None else NULL + c_path[0].name = _xcstr(name) if name is not None else NULL + c_path[0].index = index + c_path += 1 + return c_path_segments + + +cdef _find_object_path(_Element root, _ObjectPath* c_path, Py_ssize_t c_path_len, default_value): + """Follow the path to find the target element. + """ + cdef tree.xmlNode* c_node + cdef Py_ssize_t c_index + c_node = root._c_node + c_name = c_path[0].name + c_href = c_path[0].href + if c_href is NULL or c_href[0] == c'\0': + c_href = tree._getNs(c_node) + if not cetree.tagMatches(c_node, c_href, c_name): + if default_value is not _NO_DEFAULT: + return default_value + else: + raise ValueError( + f"root element does not match: need {cetree.namespacedNameFromNsName(c_href, c_name)}, got {root.tag}") + + while c_node is not NULL: + c_path_len -= 1 + if c_path_len <= 0: + break + + c_path += 1 + if c_path[0].href is not NULL: + c_href = c_path[0].href # otherwise: keep parent namespace + c_name = tree.xmlDictExists(c_node.doc.dict, c_path[0].name, -1) + if c_name is NULL: + c_name = c_path[0].name + c_node = NULL + break + c_index = c_path[0].index + c_node = c_node.last if c_index < 0 else c_node.children + c_node = _findFollowingSibling(c_node, c_href, c_name, c_index) + + if c_node is not NULL: + return cetree.elementFactory(root._doc, c_node) + elif default_value is not _NO_DEFAULT: + return default_value + else: + tag = cetree.namespacedNameFromNsName(c_href, c_name) + raise AttributeError, f"no such child: {tag}" + + +cdef _create_object_path(_Element root, _ObjectPath* c_path, + Py_ssize_t c_path_len, int replace, value): + """Follow the path to find the target element, build the missing children + as needed and set the target element to 'value'. If replace is true, an + existing value is replaced, otherwise the new value is added. + """ + cdef _Element child + cdef tree.xmlNode* c_node + cdef tree.xmlNode* c_child + cdef Py_ssize_t c_index + if c_path_len == 1: + raise TypeError, "cannot update root node" + + c_node = root._c_node + c_name = c_path[0].name + c_href = c_path[0].href + if c_href is NULL or c_href[0] == c'\0': + c_href = tree._getNs(c_node) + if not cetree.tagMatches(c_node, c_href, c_name): + raise ValueError( + f"root element does not match: need {cetree.namespacedNameFromNsName(c_href, c_name)}, got {root.tag}") + + while c_path_len > 1: + c_path_len -= 1 + c_path += 1 + if c_path[0].href is not NULL: + c_href = c_path[0].href # otherwise: keep parent namespace + c_index = c_path[0].index + c_name = tree.xmlDictExists(c_node.doc.dict, c_path[0].name, -1) + if c_name is NULL: + c_name = c_path[0].name + c_child = NULL + else: + c_child = c_node.last if c_index < 0 else c_node.children + c_child = _findFollowingSibling(c_child, c_href, c_name, c_index) + + if c_child is not NULL: + c_node = c_child + elif c_index != 0: + raise TypeError, "creating indexed path attributes is not supported" + elif c_path_len == 1: + _appendValue(cetree.elementFactory(root._doc, c_node), + cetree.namespacedNameFromNsName(c_href, c_name), + value) + return + else: + child = cetree.makeSubElement( + cetree.elementFactory(root._doc, c_node), + cetree.namespacedNameFromNsName(c_href, c_name), + None, None, None, None) + c_node = child._c_node + + # if we get here, the entire path was already there + if replace: + element = cetree.elementFactory(root._doc, c_node) + _replaceElement(element, value) + else: + _appendValue(cetree.elementFactory(root._doc, c_node.parent), + cetree.namespacedName(c_node), value) + + +cdef list _build_descendant_paths(tree.xmlNode* c_node, prefix_string): + """Returns a list of all descendant paths. + """ + cdef list path, path_list + tag = cetree.namespacedName(c_node) + if prefix_string: + if prefix_string[-1] != '.': + prefix_string += '.' + prefix_string = prefix_string + tag + else: + prefix_string = tag + path = [prefix_string] + path_list = [] + _recursive_build_descendant_paths(c_node, path, path_list) + return path_list + + +cdef int _recursive_build_descendant_paths(tree.xmlNode* c_node, + list path, list path_list) except -1: + """Fills the list 'path_list' with all descendant paths, initial prefix + being in the list 'path'. + """ + cdef tree.xmlNode* c_child + tags = {} + path_list.append('.'.join(path)) + c_href = tree._getNs(c_node) + c_child = c_node.children + while c_child is not NULL: + while c_child.type != tree.XML_ELEMENT_NODE: + c_child = c_child.next + if c_child is NULL: + return 0 + if c_href is tree._getNs(c_child): + tag = pyunicode(c_child.name) + elif c_href is not NULL and tree._getNs(c_child) is NULL: + # special case: parent has namespace, child does not + tag = '{}' + pyunicode(c_child.name) + else: + tag = cetree.namespacedName(c_child) + count = tags.get(tag) + if count is None: + tags[tag] = 1 + else: + tags[tag] = count + 1 + tag += f'[{count}]' + path.append(tag) + _recursive_build_descendant_paths(c_child, path, path_list) + del path[-1] + c_child = c_child.next + return 0 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/proxy.pxi b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/proxy.pxi new file mode 100644 index 0000000000000000000000000000000000000000..0e6cf19ef7ffadcb061aff55cdd6acb4d2ccfefa --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/proxy.pxi @@ -0,0 +1,622 @@ +# Proxy functions and low level node allocation stuff + +# Proxies represent elements, their reference is stored in the C +# structure of the respective node to avoid multiple instantiation of +# the Python class. + +@cython.linetrace(False) +@cython.profile(False) +cdef inline _Element getProxy(xmlNode* c_node): + """Get a proxy for a given node. + """ + #print "getProxy for:", c_node + if c_node is not NULL and c_node._private is not NULL: + return <_Element>c_node._private + else: + return None + + +@cython.linetrace(False) +@cython.profile(False) +cdef inline bint hasProxy(xmlNode* c_node): + if c_node._private is NULL: + return False + return True + + +@cython.linetrace(False) +@cython.profile(False) +cdef inline int _registerProxy(_Element proxy, _Document doc, + xmlNode* c_node) except -1: + """Register a proxy and type for the node it's proxying for. + """ + #print "registering for:", proxy._c_node + assert not hasProxy(c_node), "double registering proxy!" + proxy._doc = doc + proxy._c_node = c_node + c_node._private = proxy + return 0 + + +@cython.linetrace(False) +@cython.profile(False) +cdef inline int _unregisterProxy(_Element proxy) except -1: + """Unregister a proxy for the node it's proxying for. + """ + cdef xmlNode* c_node = proxy._c_node + assert c_node._private is proxy, "Tried to unregister unknown proxy" + c_node._private = NULL + return 0 + + +################################################################################ +# temporarily make a node the root node of its document + +cdef xmlDoc* _fakeRootDoc(xmlDoc* c_base_doc, xmlNode* c_node) except NULL: + return _plainFakeRootDoc(c_base_doc, c_node, 1) + +cdef xmlDoc* _plainFakeRootDoc(xmlDoc* c_base_doc, xmlNode* c_node, + bint with_siblings) except NULL: + # build a temporary document that has the given node as root node + # note that copy and original must not be modified during its lifetime!! + # always call _destroyFakeDoc() after use! + cdef xmlNode* c_child + cdef xmlNode* c_root + cdef xmlNode* c_new_root + cdef xmlDoc* c_doc + if with_siblings or (c_node.prev is NULL and c_node.next is NULL): + c_root = tree.xmlDocGetRootElement(c_base_doc) + if c_root is c_node: + # already the root node, no siblings + return c_base_doc + + c_doc = _copyDoc(c_base_doc, 0) # non recursive! + c_new_root = tree.xmlDocCopyNode(c_node, c_doc, 2) # non recursive! + tree.xmlDocSetRootElement(c_doc, c_new_root) + _copyParentNamespaces(c_node, c_new_root) + + c_new_root.children = c_node.children + c_new_root.last = c_node.last + c_new_root.next = c_new_root.prev = NULL + + # store original node + c_doc._private = c_node + + # divert parent pointers of children + c_child = c_new_root.children + while c_child is not NULL: + c_child.parent = c_new_root + c_child = c_child.next + + c_doc.children = c_new_root + return c_doc + +cdef void _destroyFakeDoc(xmlDoc* c_base_doc, xmlDoc* c_doc) noexcept: + # delete a temporary document + cdef xmlNode* c_child + cdef xmlNode* c_parent + cdef xmlNode* c_root + if c_doc is c_base_doc: + return + c_root = tree.xmlDocGetRootElement(c_doc) + + # restore parent pointers of children + c_parent = c_doc._private + c_child = c_root.children + while c_child is not NULL: + c_child.parent = c_parent + c_child = c_child.next + + # prevent recursive removal of children + c_root.children = c_root.last = NULL + tree.xmlFreeDoc(c_doc) + +cdef _Element _fakeDocElementFactory(_Document doc, xmlNode* c_element): + """Special element factory for cases where we need to create a fake + root document, but still need to instantiate arbitrary nodes from + it. If we instantiate the fake root node, things will turn bad + when it's destroyed. + + Instead, if we are asked to instantiate the fake root node, we + instantiate the original node instead. + """ + if c_element.doc is not doc._c_doc: + if c_element.doc._private is not NULL: + if c_element is c_element.doc.children: + c_element = c_element.doc._private + #assert c_element.type == tree.XML_ELEMENT_NODE + return _elementFactory(doc, c_element) + +################################################################################ +# support for freeing tree elements when proxy objects are destroyed + +cdef int attemptDeallocation(xmlNode* c_node) noexcept: + """Attempt deallocation of c_node (or higher up in tree). + """ + cdef xmlNode* c_top + # could be we actually aren't referring to the tree at all + if c_node is NULL: + #print "not freeing, node is NULL" + return 0 + c_top = getDeallocationTop(c_node) + if c_top is not NULL: + #print "freeing:", c_top.name + _removeText(c_top.next) # tail + tree.xmlFreeNode(c_top) + return 1 + return 0 + +cdef xmlNode* getDeallocationTop(xmlNode* c_node) noexcept: + """Return the top of the tree that can be deallocated, or NULL. + """ + cdef xmlNode* c_next + #print "trying to do deallocating:", c_node.type + if hasProxy(c_node): + #print "Not freeing: proxies still exist" + return NULL + while c_node.parent is not NULL: + c_node = c_node.parent + #print "checking:", c_current.type + if c_node.type == tree.XML_DOCUMENT_NODE or \ + c_node.type == tree.XML_HTML_DOCUMENT_NODE: + #print "not freeing: still in doc" + return NULL + # if we're still attached to the document, don't deallocate + if hasProxy(c_node): + #print "Not freeing: proxies still exist" + return NULL + # see whether we have children to deallocate + if not canDeallocateChildNodes(c_node): + return NULL + # see whether we have siblings to deallocate + c_next = c_node.prev + while c_next: + if _isElement(c_next): + if hasProxy(c_next) or not canDeallocateChildNodes(c_next): + return NULL + c_next = c_next.prev + c_next = c_node.next + while c_next: + if _isElement(c_next): + if hasProxy(c_next) or not canDeallocateChildNodes(c_next): + return NULL + c_next = c_next.next + return c_node + +cdef int canDeallocateChildNodes(xmlNode* c_parent) noexcept: + cdef xmlNode* c_node + c_node = c_parent.children + tree.BEGIN_FOR_EACH_ELEMENT_FROM(c_parent, c_node, 1) + if hasProxy(c_node): + return 0 + tree.END_FOR_EACH_ELEMENT_FROM(c_node) + return 1 + +################################################################################ +# fix _Document references and namespaces when a node changes documents + +cdef void _copyParentNamespaces(xmlNode* c_from_node, xmlNode* c_to_node) noexcept nogil: + """Copy the namespaces of all ancestors of c_from_node to c_to_node. + """ + cdef xmlNode* c_parent + cdef xmlNs* c_ns + cdef xmlNs* c_new_ns + cdef int prefix_known + c_parent = c_from_node.parent + while c_parent and (tree._isElementOrXInclude(c_parent) or + c_parent.type == tree.XML_DOCUMENT_NODE): + c_new_ns = c_parent.nsDef + while c_new_ns: + # libxml2 will check if the prefix is already defined + tree.xmlNewNs(c_to_node, c_new_ns.href, c_new_ns.prefix) + c_new_ns = c_new_ns.next + c_parent = c_parent.parent + + +ctypedef struct _ns_update_map: + xmlNs* old + xmlNs* new + + +ctypedef struct _nscache: + _ns_update_map* ns_map + size_t size + size_t last + + +cdef int _growNsCache(_nscache* c_ns_cache) except -1: + cdef _ns_update_map* ns_map_ptr + if c_ns_cache.size == 0: + c_ns_cache.size = 20 + else: + c_ns_cache.size *= 2 + ns_map_ptr = <_ns_update_map*> python.lxml_realloc( + c_ns_cache.ns_map, c_ns_cache.size, sizeof(_ns_update_map)) + if not ns_map_ptr: + python.lxml_free(c_ns_cache.ns_map) + c_ns_cache.ns_map = NULL + raise MemoryError() + c_ns_cache.ns_map = ns_map_ptr + return 0 + + +cdef inline int _appendToNsCache(_nscache* c_ns_cache, + xmlNs* c_old_ns, xmlNs* c_new_ns) except -1: + if c_ns_cache.last >= c_ns_cache.size: + _growNsCache(c_ns_cache) + c_ns_cache.ns_map[c_ns_cache.last] = _ns_update_map(old=c_old_ns, new=c_new_ns) + c_ns_cache.last += 1 + + +cdef int _stripRedundantNamespaceDeclarations(xmlNode* c_element, _nscache* c_ns_cache, + xmlNs** c_del_ns_list) except -1: + """Removes namespace declarations from an element that are already + defined in its parents. Does not free the xmlNs's, just prepends + them to the c_del_ns_list. + """ + cdef xmlNs* c_ns + cdef xmlNs* c_ns_next + cdef xmlNs** c_nsdef + # use a xmlNs** to handle assignments to "c_element.nsDef" correctly + c_nsdef = &c_element.nsDef + while c_nsdef[0] is not NULL: + c_ns = tree.xmlSearchNsByHref( + c_element.doc, c_element.parent, c_nsdef[0].href) + if c_ns is NULL: + # new namespace href => keep and cache the ns declaration + _appendToNsCache(c_ns_cache, c_nsdef[0], c_nsdef[0]) + c_nsdef = &c_nsdef[0].next + else: + # known namespace href => cache mapping and strip old ns + _appendToNsCache(c_ns_cache, c_nsdef[0], c_ns) + # cut out c_nsdef.next and prepend it to garbage chain + c_ns_next = c_nsdef[0].next + c_nsdef[0].next = c_del_ns_list[0] + c_del_ns_list[0] = c_nsdef[0] + c_nsdef[0] = c_ns_next + return 0 + + +cdef void _cleanUpFromNamespaceAdaptation(xmlNode* c_start_node, + _nscache* c_ns_cache, xmlNs* c_del_ns_list) noexcept: + # Try to recover from exceptions with really bad timing. We were in the middle + # of ripping out xmlNS-es and likely ran out of memory. Try to fix up the tree + # by re-adding the original xmlNs declarations (which might still be used in some + # places). + if c_ns_cache.ns_map: + python.lxml_free(c_ns_cache.ns_map) + if c_del_ns_list: + if not c_start_node.nsDef: + c_start_node.nsDef = c_del_ns_list + else: + c_ns = c_start_node.nsDef + while c_ns.next: + c_ns = c_ns.next + c_ns.next = c_del_ns_list + + +cdef int moveNodeToDocument(_Document doc, xmlDoc* c_source_doc, + xmlNode* c_element) except -1: + """Fix the xmlNs pointers of a node and its subtree that were moved. + + Originally copied from libxml2's xmlReconciliateNs(). Expects + libxml2 doc pointers of node to be correct already, but fixes + _Document references. + + For each node in the subtree, we do this: + + 1) Remove redundant declarations of namespace that are already + defined in its parents. + + 2) Replace namespaces that are *not* defined on the node or its + parents by the equivalent namespace declarations that *are* + defined on the node or its parents (possibly using a different + prefix). If a namespace is unknown, declare a new one on the + node. + + 3) Reassign the names of tags and attribute from the dict of the + target document *iff* it is different from the dict used in the + source subtree. + + 4) Set the Document reference to the new Document (if different). + This is done on backtracking to keep the original Document + alive as long as possible, until all its elements are updated. + + Note that the namespace declarations are removed from the tree in + step 1), but freed only after the complete subtree was traversed + and all occurrences were replaced by tree-internal pointers. + """ + cdef xmlNode* c_start_node + cdef xmlNode* c_node + cdef xmlDoc* c_doc = doc._c_doc + cdef tree.xmlAttr* c_attr + cdef char* c_name + cdef _nscache c_ns_cache = [NULL, 0, 0] + cdef xmlNs* c_del_ns_list = NULL + cdef proxy_count = 0 + + if not tree._isElementOrXInclude(c_element): + return 0 + + c_start_node = c_element + + tree.BEGIN_FOR_EACH_FROM(c_element, c_element, 1) + if tree._isElementOrXInclude(c_element): + if hasProxy(c_element): + proxy_count += 1 + + # 1) cut out namespaces defined here that are already known by + # the ancestors + if c_element.nsDef is not NULL: + try: + _stripRedundantNamespaceDeclarations(c_element, &c_ns_cache, &c_del_ns_list) + except: + _cleanUpFromNamespaceAdaptation(c_start_node, &c_ns_cache, c_del_ns_list) + raise + + # 2) make sure the namespaces of an element and its attributes + # are declared in this document (i.e. on the node or its parents) + if c_element.ns is not NULL: + _fixCNs(doc, c_start_node, c_element, &c_ns_cache, c_del_ns_list) + + c_node = c_element.properties + while c_node is not NULL: + if c_node.ns is not NULL: + _fixCNs(doc, c_start_node, c_node, &c_ns_cache, c_del_ns_list) + c_node = c_node.next + + tree.END_FOR_EACH_FROM(c_element) + + # free now unused namespace declarations + if c_del_ns_list is not NULL: + tree.xmlFreeNsList(c_del_ns_list) + + # cleanup + if c_ns_cache.ns_map is not NULL: + python.lxml_free(c_ns_cache.ns_map) + + # 3) fix the names in the tree if we moved it from a different thread + if doc._c_doc.dict is not c_source_doc.dict: + fixThreadDictNames(c_start_node, c_source_doc.dict, doc._c_doc.dict) + + # 4) fix _Document references + # (and potentially deallocate the source document) + if proxy_count > 0: + if proxy_count == 1 and c_start_node._private is not NULL: + proxy = getProxy(c_start_node) + if proxy is not None: + if proxy._doc is not doc: + proxy._doc = doc + else: + fixElementDocument(c_start_node, doc, proxy_count) + else: + fixElementDocument(c_start_node, doc, proxy_count) + + return 0 + + +cdef void _setTreeDoc(xmlNode* c_node, xmlDoc* c_doc) noexcept: + """Adaptation of 'xmlSetTreeDoc()' that deep-fixes the document links iteratively. + It avoids https://gitlab.gnome.org/GNOME/libxml2/issues/42 + """ + tree.BEGIN_FOR_EACH_FROM(c_node, c_node, 1) + if c_node.type == tree.XML_ELEMENT_NODE: + c_attr = c_node.properties + while c_attr: + if c_attr.atype == tree.XML_ATTRIBUTE_ID: + tree.xmlRemoveID(c_node.doc, c_attr) + c_attr.doc = c_doc + _fixDocChildren(c_attr.children, c_doc) + c_attr = c_attr.next + # Set doc link for all nodes, not only elements. + c_node.doc = c_doc + tree.END_FOR_EACH_FROM(c_node) + + +cdef inline void _fixDocChildren(xmlNode* c_child, xmlDoc* c_doc) noexcept: + while c_child: + c_child.doc = c_doc + if c_child.children: + _fixDocChildren(c_child.children, c_doc) + c_child = c_child.next + + +cdef int _fixCNs(_Document doc, xmlNode* c_start_node, xmlNode* c_node, + _nscache* c_ns_cache, xmlNs* c_del_ns_list) except -1: + cdef xmlNs* c_ns = NULL + cdef bint is_prefixed_attr = (c_node.type == tree.XML_ATTRIBUTE_NODE and c_node.ns.prefix) + + for ns_map in c_ns_cache.ns_map[:c_ns_cache.last]: + if c_node.ns is ns_map.old: + if is_prefixed_attr and not ns_map.new.prefix: + # avoid dropping prefix from attributes + continue + c_ns = ns_map.new + break + + if c_ns: + c_node.ns = c_ns + else: + # not in cache or not acceptable + # => find a replacement from this document + try: + c_ns = doc._findOrBuildNodeNs( + c_start_node, c_node.ns.href, c_node.ns.prefix, + c_node.type == tree.XML_ATTRIBUTE_NODE) + c_node.ns = c_ns + _appendToNsCache(c_ns_cache, c_node.ns, c_ns) + except: + _cleanUpFromNamespaceAdaptation(c_start_node, c_ns_cache, c_del_ns_list) + raise + return 0 + + +cdef int fixElementDocument(xmlNode* c_element, _Document doc, + size_t proxy_count) except -1: + cdef xmlNode* c_node = c_element + cdef _Element proxy = None # init-to-None required due to fake-loop below + tree.BEGIN_FOR_EACH_FROM(c_element, c_node, 1) + if c_node._private is not NULL: + proxy = getProxy(c_node) + if proxy is not None: + if proxy._doc is not doc: + proxy._doc = doc + proxy_count -= 1 + if proxy_count == 0: + return 0 + tree.END_FOR_EACH_FROM(c_node) + + +cdef void fixThreadDictNames(xmlNode* c_element, + tree.xmlDict* c_src_dict, + tree.xmlDict* c_dict) noexcept nogil: + # re-assign the names of tags and attributes + # + # this should only be called when the element is based on a + # different libxml2 tag name dictionary + if c_element.type == tree.XML_DOCUMENT_NODE or \ + c_element.type == tree.XML_HTML_DOCUMENT_NODE: + # may define "xml" namespace + fixThreadDictNsForNode(c_element, c_src_dict, c_dict) + if c_element.doc.extSubset: + fixThreadDictNamesForDtd(c_element.doc.extSubset, c_src_dict, c_dict) + if c_element.doc.intSubset: + fixThreadDictNamesForDtd(c_element.doc.intSubset, c_src_dict, c_dict) + c_element = c_element.children + while c_element is not NULL: + fixThreadDictNamesForNode(c_element, c_src_dict, c_dict) + c_element = c_element.next + elif tree._isElementOrXInclude(c_element): + fixThreadDictNamesForNode(c_element, c_src_dict, c_dict) + + +cdef inline void _fixThreadDictPtr(const_xmlChar** c_ptr, + tree.xmlDict* c_src_dict, + tree.xmlDict* c_dict) noexcept nogil: + c_str = c_ptr[0] + if c_str and c_src_dict and tree.xmlDictOwns(c_src_dict, c_str): + # return value can be NULL on memory error, but we don't handle that here + c_str = tree.xmlDictLookup(c_dict, c_str, -1) + if c_str: + c_ptr[0] = c_str + + +cdef void fixThreadDictNamesForNode(xmlNode* c_element, + tree.xmlDict* c_src_dict, + tree.xmlDict* c_dict) noexcept nogil: + cdef xmlNode* c_node = c_element + tree.BEGIN_FOR_EACH_FROM(c_element, c_node, 1) + if c_node.type in (tree.XML_ELEMENT_NODE, tree.XML_XINCLUDE_START): + fixThreadDictNamesForAttributes( + c_node.properties, c_src_dict, c_dict) + fixThreadDictNsForNode(c_node, c_src_dict, c_dict) + _fixThreadDictPtr(&c_node.name, c_src_dict, c_dict) + elif c_node.type == tree.XML_TEXT_NODE: + # libxml2's SAX2 parser interns some indentation space + fixThreadDictContentForNode(c_node, c_src_dict, c_dict) + elif c_node.type == tree.XML_COMMENT_NODE: + pass # don't touch c_node.name + else: + _fixThreadDictPtr(&c_node.name, c_src_dict, c_dict) + tree.END_FOR_EACH_FROM(c_node) + + +cdef inline void fixThreadDictNamesForAttributes(tree.xmlAttr* c_attr, + tree.xmlDict* c_src_dict, + tree.xmlDict* c_dict) noexcept nogil: + cdef xmlNode* c_child + cdef xmlNode* c_node = c_attr + while c_node is not NULL: + if c_node.type not in (tree.XML_TEXT_NODE, tree.XML_COMMENT_NODE): + _fixThreadDictPtr(&c_node.name, c_src_dict, c_dict) + # libxml2 keeps some (!) attribute values in the dict + c_child = c_node.children + while c_child is not NULL: + fixThreadDictContentForNode(c_child, c_src_dict, c_dict) + c_child = c_child.next + c_node = c_node.next + + +cdef inline void fixThreadDictContentForNode(xmlNode* c_node, + tree.xmlDict* c_src_dict, + tree.xmlDict* c_dict) noexcept nogil: + if c_node.content is not NULL and \ + c_node.content is not &c_node.properties: + if tree.xmlDictOwns(c_src_dict, c_node.content): + # result can be NULL on memory error, but we don't handle that here + c_node.content = tree.xmlDictLookup(c_dict, c_node.content, -1) + + +cdef inline void fixThreadDictNsForNode(xmlNode* c_node, + tree.xmlDict* c_src_dict, + tree.xmlDict* c_dict) noexcept nogil: + cdef xmlNs* c_ns = c_node.nsDef + while c_ns is not NULL: + _fixThreadDictPtr(&c_ns.href, c_src_dict, c_dict) + _fixThreadDictPtr(&c_ns.prefix, c_src_dict, c_dict) + c_ns = c_ns.next + + +cdef void fixThreadDictNamesForDtd(tree.xmlDtd* c_dtd, + tree.xmlDict* c_src_dict, + tree.xmlDict* c_dict) noexcept nogil: + cdef xmlNode* c_node + cdef tree.xmlElement* c_element + cdef tree.xmlAttribute* c_attribute + cdef tree.xmlEntity* c_entity + + c_node = c_dtd.children + while c_node: + if c_node.type == tree.XML_ELEMENT_DECL: + c_element = c_node + if c_element.content: + _fixThreadDictPtr(&c_element.content.name, c_src_dict, c_dict) + _fixThreadDictPtr(&c_element.content.prefix, c_src_dict, c_dict) + c_attribute = c_element.attributes + while c_attribute: + if tree.LIBXML_VERSION < 21500: + # libxml2 2.15 no longer stores default values in the dict. + # See https://gitlab.gnome.org/GNOME/libxml2/-/commit/24628f25 + _fixThreadDictPtr(&c_attribute.defaultValue, c_src_dict, c_dict) + _fixThreadDictPtr(&c_attribute.name, c_src_dict, c_dict) + _fixThreadDictPtr(&c_attribute.prefix, c_src_dict, c_dict) + _fixThreadDictPtr(&c_attribute.elem, c_src_dict, c_dict) + c_attribute = c_attribute.nexth + elif c_node.type == tree.XML_ENTITY_DECL: + c_entity = c_node + _fixThreadDictPtr(&c_entity.name, c_src_dict, c_dict) + _fixThreadDictPtr(&c_entity.ExternalID, c_src_dict, c_dict) + _fixThreadDictPtr(&c_entity.SystemID, c_src_dict, c_dict) + _fixThreadDictPtr(&c_entity.content, c_src_dict, c_dict) + c_node = c_node.next + + +################################################################################ +# adopt an xmlDoc from an external libxml2 document source + +cdef _Document _adoptForeignDoc(xmlDoc* c_doc, _BaseParser parser=None, bint is_owned=True): + """Convert and wrap an externally produced xmlDoc for use in lxml. + Assures that all '_private' pointers are NULL to prevent accidental + dereference into lxml proxy objects. + """ + if c_doc is NULL: + raise ValueError("Illegal document provided: NULL") + if c_doc.type not in (tree.XML_DOCUMENT_NODE, tree.XML_HTML_DOCUMENT_NODE): + doc_type = c_doc.type + if is_owned: + tree.xmlFreeDoc(c_doc) + raise ValueError(f"Illegal document provided: expected XML or HTML, found {doc_type}") + + cdef xmlNode* c_node = c_doc + + if is_owned: + tree.BEGIN_FOR_EACH_FROM(c_doc, c_node, 1) + c_node._private = NULL + tree.END_FOR_EACH_FROM(c_node) + else: + # create a fresh copy that lxml owns + c_doc = tree.xmlCopyDoc(c_doc, 1) + if c_doc is NULL: + raise MemoryError() + + return _documentFactory(c_doc, parser) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/public-api.pxi b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/public-api.pxi new file mode 100644 index 0000000000000000000000000000000000000000..fb8b2a2ced120b69c311270adba08924d65980a6 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/public-api.pxi @@ -0,0 +1,178 @@ +# Public C API for lxml.etree + +cdef public api _Element deepcopyNodeToDocument(_Document doc, xmlNode* c_root): + "Recursively copy the element into the document. doc is not modified." + cdef xmlNode* c_node + c_node = _copyNodeToDoc(c_root, doc._c_doc) + return _elementFactory(doc, c_node) + +cdef public api _ElementTree elementTreeFactory(_Element context_node): + _assertValidNode(context_node) + return newElementTree(context_node, _ElementTree) + +cdef public api _ElementTree newElementTree(_Element context_node, + object subclass): + if context_node is NULL or context_node is None: + raise TypeError + _assertValidNode(context_node) + return _newElementTree(context_node._doc, context_node, subclass) + +cdef public api _ElementTree adoptExternalDocument(xmlDoc* c_doc, parser, bint is_owned): + if c_doc is NULL: + raise TypeError + doc = _adoptForeignDoc(c_doc, parser, is_owned) + return _elementTreeFactory(doc, None) + +cdef public api _Element elementFactory(_Document doc, xmlNode* c_node): + if c_node is NULL or doc is None: + raise TypeError + return _elementFactory(doc, c_node) + +cdef public api _Element makeElement(tag, _Document doc, parser, + text, tail, attrib, nsmap): + return _makeElement(tag, NULL, doc, parser, text, tail, attrib, nsmap, None) + +cdef public api _Element makeSubElement(_Element parent, tag, text, tail, + attrib, nsmap): + _assertValidNode(parent) + return _makeSubElement(parent, tag, text, tail, attrib, nsmap, None) + +cdef public api void setElementClassLookupFunction( + _element_class_lookup_function function, state): + _setElementClassLookupFunction(function, state) + +cdef public api object lookupDefaultElementClass(state, doc, xmlNode* c_node): + return _lookupDefaultElementClass(state, doc, c_node) + +cdef public api object lookupNamespaceElementClass(state, doc, xmlNode* c_node): + return _find_nselement_class(state, doc, c_node) + +cdef public api object callLookupFallback(FallbackElementClassLookup lookup, + _Document doc, xmlNode* c_node): + return _callLookupFallback(lookup, doc, c_node) + +cdef public api int tagMatches(xmlNode* c_node, const_xmlChar* c_href, const_xmlChar* c_name): + if c_node is NULL: + return -1 + return _tagMatches(c_node, c_href, c_name) + +cdef public api _Document documentOrRaise(object input): + return _documentOrRaise(input) + +cdef public api _Element rootNodeOrRaise(object input): + return _rootNodeOrRaise(input) + +cdef public api bint hasText(xmlNode* c_node): + return _hasText(c_node) + +cdef public api bint hasTail(xmlNode* c_node): + return _hasTail(c_node) + +cdef public api unicode textOf(xmlNode* c_node): + if c_node is NULL: + return None + return _collectText(c_node.children) + +cdef public api unicode tailOf(xmlNode* c_node): + if c_node is NULL: + return None + return _collectText(c_node.next) + +cdef public api int setNodeText(xmlNode* c_node, text) except -1: + if c_node is NULL: + raise ValueError + return _setNodeText(c_node, text) + +cdef public api int setTailText(xmlNode* c_node, text) except -1: + if c_node is NULL: + raise ValueError + return _setTailText(c_node, text) + +cdef public api unicode attributeValue(xmlNode* c_element, xmlAttr* c_attrib_node): + return _attributeValue(c_element, c_attrib_node) + +cdef public api unicode attributeValueFromNsName(xmlNode* c_element, + const_xmlChar* ns, const_xmlChar* name): + return _attributeValueFromNsName(c_element, ns, name) + +cdef public api object getAttributeValue(_Element element, key, default): + _assertValidNode(element) + return _getAttributeValue(element, key, default) + +cdef public api object iterattributes(_Element element, int keysvalues): + _assertValidNode(element) + return _attributeIteratorFactory(element, keysvalues) + +cdef public api list collectAttributes(xmlNode* c_element, int keysvalues): + return _collectAttributes(c_element, keysvalues) + +cdef public api int setAttributeValue(_Element element, key, value) except -1: + _assertValidNode(element) + return _setAttributeValue(element, key, value) + +cdef public api int delAttribute(_Element element, key) except -1: + _assertValidNode(element) + return _delAttribute(element, key) + +cdef public api int delAttributeFromNsName(tree.xmlNode* c_element, + const_xmlChar* c_href, const_xmlChar* c_name): + return _delAttributeFromNsName(c_element, c_href, c_name) + +cdef public api bint hasChild(xmlNode* c_node): + return _hasChild(c_node) + +cdef public api xmlNode* findChild(xmlNode* c_node, Py_ssize_t index): + return _findChild(c_node, index) + +cdef public api xmlNode* findChildForwards(xmlNode* c_node, Py_ssize_t index): + return _findChildForwards(c_node, index) + +cdef public api xmlNode* findChildBackwards(xmlNode* c_node, Py_ssize_t index): + return _findChildBackwards(c_node, index) + +cdef public api xmlNode* nextElement(xmlNode* c_node): + return _nextElement(c_node) + +cdef public api xmlNode* previousElement(xmlNode* c_node): + return _previousElement(c_node) + +cdef public api void appendChild(_Element parent, _Element child): + # deprecated, use appendChildToElement() instead! + _appendChild(parent, child) + +cdef public api int appendChildToElement(_Element parent, _Element child) except -1: + return _appendChild(parent, child) + +cdef public api unicode pyunicode(const_xmlChar* s): + if s is NULL: + raise TypeError + return funicode(s) + +cdef public api bytes utf8(object s): + return _utf8(s) + +cdef public api tuple getNsTag(object tag): + return _getNsTag(tag) + +cdef public api tuple getNsTagWithEmptyNs(object tag): + return _getNsTagWithEmptyNs(tag) + +cdef public api unicode namespacedName(xmlNode* c_node): + return _namespacedName(c_node) + +cdef public api unicode namespacedNameFromNsName(const_xmlChar* href, const_xmlChar* name): + return _namespacedNameFromNsName(href, name) + +cdef public api void iteratorStoreNext(_ElementIterator iterator, _Element node): + # deprecated! + iterator._storeNext(node) + +cdef public api void initTagMatch(_ElementTagMatcher matcher, tag): + # deprecated! + matcher._initTagMatch(tag) + +cdef public api tree.xmlNs* findOrBuildNodeNsPrefix( + _Document doc, xmlNode* c_node, const_xmlChar* href, const_xmlChar* prefix) except NULL: + if doc is None: + raise TypeError + return doc._findOrBuildNodeNs(c_node, href, prefix, 0) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/pyclasslookup.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/pyclasslookup.py new file mode 100644 index 0000000000000000000000000000000000000000..9e1496dfb762108154a0c6c321a5e8fcf73de909 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/pyclasslookup.py @@ -0,0 +1,3 @@ +# dummy module for backwards compatibility + +from lxml.etree import PythonElementClassLookup diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/relaxng.pxi b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/relaxng.pxi new file mode 100644 index 0000000000000000000000000000000000000000..35f875891f7e59a785518b8b70bd19ef3f0f6099 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/relaxng.pxi @@ -0,0 +1,165 @@ +# support for RelaxNG validation +from lxml.includes cimport relaxng + +cdef object _rnc2rng +try: + import rnc2rng as _rnc2rng +except ImportError: + _rnc2rng = None + + +cdef int _require_rnc2rng() except -1: + if _rnc2rng is None: + raise RelaxNGParseError( + 'compact syntax not supported (please install rnc2rng)') + return 0 + + +cdef class RelaxNGError(LxmlError): + """Base class for RelaxNG errors. + """ + +cdef class RelaxNGParseError(RelaxNGError): + """Error while parsing an XML document as RelaxNG. + """ + +cdef class RelaxNGValidateError(RelaxNGError): + """Error while validating an XML document with a RelaxNG schema. + """ + + +################################################################################ +# RelaxNG + +cdef class RelaxNG(_Validator): + """RelaxNG(self, etree=None, file=None) + Turn a document into a Relax NG validator. + + Either pass a schema as Element or ElementTree, or pass a file or + filename through the ``file`` keyword argument. + """ + cdef relaxng.xmlRelaxNG* _c_schema + def __cinit__(self): + self._c_schema = NULL + + def __init__(self, etree=None, *, file=None): + cdef _Document doc + cdef _Element root_node + cdef xmlDoc* fake_c_doc = NULL + cdef relaxng.xmlRelaxNGParserCtxt* parser_ctxt = NULL + _Validator.__init__(self) + if etree is not None: + doc = _documentOrRaise(etree) + root_node = _rootNodeOrRaise(etree) + fake_c_doc = _fakeRootDoc(doc._c_doc, root_node._c_node) + parser_ctxt = relaxng.xmlRelaxNGNewDocParserCtxt(fake_c_doc) + elif file is not None: + if _isString(file): + if file[-4:].lower() == '.rnc': + _require_rnc2rng() + rng_data_utf8 = _utf8(_rnc2rng.dumps(_rnc2rng.load(file))) + doc = _parseMemoryDocument(rng_data_utf8, parser=None, url=file) + parser_ctxt = relaxng.xmlRelaxNGNewDocParserCtxt(doc._c_doc) + else: + doc = None + filename = _encodeFilename(file) + with self._error_log: + orig_loader = _register_document_loader() + parser_ctxt = relaxng.xmlRelaxNGNewParserCtxt(_cstr(filename)) + _reset_document_loader(orig_loader) + elif (_getFilenameForFile(file) or '')[-4:].lower() == '.rnc': + _require_rnc2rng() + rng_data_utf8 = _utf8(_rnc2rng.dumps(_rnc2rng.load(file))) + doc = _parseMemoryDocument( + rng_data_utf8, parser=None, url=_getFilenameForFile(file)) + parser_ctxt = relaxng.xmlRelaxNGNewDocParserCtxt(doc._c_doc) + else: + doc = _parseDocument(file, parser=None, base_url=None) + parser_ctxt = relaxng.xmlRelaxNGNewDocParserCtxt(doc._c_doc) + else: + raise RelaxNGParseError, "No tree or file given" + + if parser_ctxt is NULL: + if fake_c_doc is not NULL: + _destroyFakeDoc(doc._c_doc, fake_c_doc) + raise RelaxNGParseError( + self._error_log._buildExceptionMessage( + "Document is not parsable as Relax NG"), + self._error_log) + + # Need a cast here because older libxml2 releases do not use 'const' in the functype. + relaxng.xmlRelaxNGSetParserStructuredErrors( + parser_ctxt, _receiveError, self._error_log) + _connectGenericErrorLog(self._error_log, xmlerror.XML_FROM_RELAXNGP) + self._c_schema = relaxng.xmlRelaxNGParse(parser_ctxt) + _connectGenericErrorLog(None) + + relaxng.xmlRelaxNGFreeParserCtxt(parser_ctxt) + if self._c_schema is NULL: + if fake_c_doc is not NULL: + _destroyFakeDoc(doc._c_doc, fake_c_doc) + raise RelaxNGParseError( + self._error_log._buildExceptionMessage( + "Document is not valid Relax NG"), + self._error_log) + if fake_c_doc is not NULL: + _destroyFakeDoc(doc._c_doc, fake_c_doc) + + def __dealloc__(self): + relaxng.xmlRelaxNGFree(self._c_schema) + + def __call__(self, etree): + """__call__(self, etree) + + Validate doc using Relax NG. + + Returns true if document is valid, false if not.""" + cdef _Document doc + cdef _Element root_node + cdef xmlDoc* c_doc + cdef relaxng.xmlRelaxNGValidCtxt* valid_ctxt + cdef int ret + + assert self._c_schema is not NULL, "RelaxNG instance not initialised" + doc = _documentOrRaise(etree) + root_node = _rootNodeOrRaise(etree) + + valid_ctxt = relaxng.xmlRelaxNGNewValidCtxt(self._c_schema) + if valid_ctxt is NULL: + raise MemoryError() + + try: + self._error_log.clear() + # Need a cast here because older libxml2 releases do not use 'const' in the functype. + relaxng.xmlRelaxNGSetValidStructuredErrors( + valid_ctxt, _receiveError, self._error_log) + _connectGenericErrorLog(self._error_log, xmlerror.XML_FROM_RELAXNGV) + c_doc = _fakeRootDoc(doc._c_doc, root_node._c_node) + with nogil: + ret = relaxng.xmlRelaxNGValidateDoc(valid_ctxt, c_doc) + _destroyFakeDoc(doc._c_doc, c_doc) + finally: + _connectGenericErrorLog(None) + relaxng.xmlRelaxNGFreeValidCtxt(valid_ctxt) + + if ret == -1: + raise RelaxNGValidateError( + "Internal error in Relax NG validation", + self._error_log) + if ret == 0: + return True + else: + return False + + @classmethod + def from_rnc_string(cls, src, base_url=None): + """Parse a RelaxNG schema in compact syntax from a text string + + Requires the rnc2rng package to be installed. + + Passing the source URL or file path of the source as 'base_url' + will enable resolving resource references relative to the source. + """ + _require_rnc2rng() + rng_str = utf8(_rnc2rng.dumps(_rnc2rng.loads(src))) + return cls(_parseMemoryDocument(rng_str, parser=None, url=base_url)) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/sax.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/sax.py new file mode 100644 index 0000000000000000000000000000000000000000..db77f6f29705e0776471c013abd0e1f97c18a457 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/sax.py @@ -0,0 +1,285 @@ +""" +SAX-based adapter to copy trees from/to the Python standard library. + +Use the `ElementTreeContentHandler` class to build an ElementTree from +SAX events. + +Use the `ElementTreeProducer` class or the `saxify()` function to fire +the SAX events of an ElementTree against a SAX ContentHandler. + +See https://lxml.de/sax.html +""" + + +from xml.sax.handler import ContentHandler +from lxml import etree +from lxml.etree import ElementTree, SubElement +from lxml.etree import Comment, ProcessingInstruction + +try: + from types import GenericAlias as _GenericAlias +except ImportError: + # Python 3.8 - we only need this as return value from "__class_getitem__" + def _GenericAlias(cls, item): + return f"{cls.__name__}[{item.__name__}]" + + +class SaxError(etree.LxmlError): + """General SAX error. + """ + + +def _getNsTag(tag): + if tag[0] == '{' and '}' in tag: + return tuple(tag[1:].split('}', 1)) + else: + return None, tag + + +class ElementTreeContentHandler(ContentHandler): + """Build an lxml ElementTree from SAX events. + """ + def __init__(self, makeelement=None): + ContentHandler.__init__(self) + self._root = None + self._root_siblings = [] + self._element_stack = [] + self._default_ns = None + self._ns_mapping = { None : [None] } + self._new_mappings = {} + if makeelement is None: + makeelement = etree.Element + self._makeelement = makeelement + + def _get_etree(self): + "Contains the generated ElementTree after parsing is finished." + return ElementTree(self._root) + + etree = property(_get_etree, doc=_get_etree.__doc__) + + def setDocumentLocator(self, locator): + pass + + def startDocument(self): + pass + + def endDocument(self): + pass + + def startPrefixMapping(self, prefix, uri): + self._new_mappings[prefix] = uri + try: + self._ns_mapping[prefix].append(uri) + except KeyError: + self._ns_mapping[prefix] = [uri] + if prefix is None: + self._default_ns = uri + + def endPrefixMapping(self, prefix): + ns_uri_list = self._ns_mapping[prefix] + ns_uri_list.pop() + if prefix is None: + self._default_ns = ns_uri_list[-1] + + def _buildTag(self, ns_name_tuple): + ns_uri, local_name = ns_name_tuple + if ns_uri: + el_tag = "{%s}%s" % ns_name_tuple + elif self._default_ns: + el_tag = "{%s}%s" % (self._default_ns, local_name) + else: + el_tag = local_name + return el_tag + + def startElementNS(self, ns_name, qname, attributes=None): + el_name = self._buildTag(ns_name) + if attributes: + attrs = {} + try: + iter_attributes = attributes.iteritems() + except AttributeError: + iter_attributes = attributes.items() + + for name_tuple, value in iter_attributes: + if name_tuple[0]: + attr_name = "{%s}%s" % name_tuple + else: + attr_name = name_tuple[1] + attrs[attr_name] = value + else: + attrs = None + + element_stack = self._element_stack + if self._root is None: + element = self._root = \ + self._makeelement(el_name, attrs, self._new_mappings) + if self._root_siblings and hasattr(element, 'addprevious'): + for sibling in self._root_siblings: + element.addprevious(sibling) + del self._root_siblings[:] + else: + element = SubElement(element_stack[-1], el_name, + attrs, self._new_mappings) + element_stack.append(element) + + self._new_mappings.clear() + + def processingInstruction(self, target, data): + pi = ProcessingInstruction(target, data) + if self._root is None: + self._root_siblings.append(pi) + else: + self._element_stack[-1].append(pi) + + def endElementNS(self, ns_name, qname): + element = self._element_stack.pop() + el_tag = self._buildTag(ns_name) + if el_tag != element.tag: + raise SaxError("Unexpected element closed: " + el_tag) + + def startElement(self, name, attributes=None): + if attributes: + attributes = {(None, k): v for k, v in attributes.items()} + self.startElementNS((None, name), name, attributes) + + def endElement(self, name): + self.endElementNS((None, name), name) + + def characters(self, data): + last_element = self._element_stack[-1] + try: + # if there already is a child element, we must append to its tail + last_element = last_element[-1] + except IndexError: + # otherwise: append to the text + last_element.text = (last_element.text or '') + data + else: + last_element.tail = (last_element.tail or '') + data + + ignorableWhitespace = characters + + # Allow subscripting sax.ElementTreeContentHandler in type annotions (PEP 560) + def __class_getitem__(cls, item): + return _GenericAlias(cls, item) + + +class ElementTreeProducer: + """Produces SAX events for an element and children. + """ + def __init__(self, element_or_tree, content_handler): + try: + element = element_or_tree.getroot() + except AttributeError: + element = element_or_tree + self._element = element + self._content_handler = content_handler + from xml.sax.xmlreader import AttributesNSImpl as attr_class + self._attr_class = attr_class + self._empty_attributes = attr_class({}, {}) + + def saxify(self): + self._content_handler.startDocument() + + element = self._element + if hasattr(element, 'getprevious'): + siblings = [] + sibling = element.getprevious() + while getattr(sibling, 'tag', None) is ProcessingInstruction: + siblings.append(sibling) + sibling = sibling.getprevious() + for sibling in siblings[::-1]: + self._recursive_saxify(sibling, {}) + + self._recursive_saxify(element, {}) + + if hasattr(element, 'getnext'): + sibling = element.getnext() + while getattr(sibling, 'tag', None) is ProcessingInstruction: + self._recursive_saxify(sibling, {}) + sibling = sibling.getnext() + + self._content_handler.endDocument() + + def _recursive_saxify(self, element, parent_nsmap): + content_handler = self._content_handler + tag = element.tag + if tag is Comment or tag is ProcessingInstruction: + if tag is ProcessingInstruction: + content_handler.processingInstruction( + element.target, element.text) + tail = element.tail + if tail: + content_handler.characters(tail) + return + + element_nsmap = element.nsmap + new_prefixes = [] + if element_nsmap != parent_nsmap: + # There have been updates to the namespace + for prefix, ns_uri in element_nsmap.items(): + if parent_nsmap.get(prefix) != ns_uri: + new_prefixes.append( (prefix, ns_uri) ) + + attribs = element.items() + if attribs: + attr_values = {} + attr_qnames = {} + for attr_ns_name, value in attribs: + attr_ns_tuple = _getNsTag(attr_ns_name) + attr_values[attr_ns_tuple] = value + attr_qnames[attr_ns_tuple] = self._build_qname( + attr_ns_tuple[0], attr_ns_tuple[1], element_nsmap, + preferred_prefix=None, is_attribute=True) + sax_attributes = self._attr_class(attr_values, attr_qnames) + else: + sax_attributes = self._empty_attributes + + ns_uri, local_name = _getNsTag(tag) + qname = self._build_qname( + ns_uri, local_name, element_nsmap, element.prefix, is_attribute=False) + + for prefix, uri in new_prefixes: + content_handler.startPrefixMapping(prefix, uri) + content_handler.startElementNS( + (ns_uri, local_name), qname, sax_attributes) + text = element.text + if text: + content_handler.characters(text) + for child in element: + self._recursive_saxify(child, element_nsmap) + content_handler.endElementNS((ns_uri, local_name), qname) + for prefix, uri in new_prefixes: + content_handler.endPrefixMapping(prefix) + tail = element.tail + if tail: + content_handler.characters(tail) + + def _build_qname(self, ns_uri, local_name, nsmap, preferred_prefix, is_attribute): + if ns_uri is None: + return local_name + + if not is_attribute and nsmap.get(preferred_prefix) == ns_uri: + prefix = preferred_prefix + else: + # Pick the first matching prefix, in alphabetical order. + candidates = [ + pfx for (pfx, uri) in nsmap.items() + if pfx is not None and uri == ns_uri + ] + prefix = ( + candidates[0] if len(candidates) == 1 + else min(candidates) if candidates + else None + ) + + if prefix is None: + # Default namespace + return local_name + return prefix + ':' + local_name + + +def saxify(element_or_tree, content_handler): + """One-shot helper to generate SAX events from an XML tree and fire + them against a SAX ContentHandler. + """ + return ElementTreeProducer(element_or_tree, content_handler).saxify() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/schematron.pxi b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/schematron.pxi new file mode 100644 index 0000000000000000000000000000000000000000..650e34b2b4c562517479b56a8f6f45b7111efafd --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/schematron.pxi @@ -0,0 +1,173 @@ +# support for Schematron validation +from lxml.includes cimport schematron + + +cdef class SchematronError(LxmlError): + """Base class of all Schematron errors. + """ + +cdef class SchematronParseError(SchematronError): + """Error while parsing an XML document as Schematron schema. + """ + +cdef class SchematronValidateError(SchematronError): + """Error while validating an XML document with a Schematron schema. + """ + + +################################################################################ +# Schematron + +cdef class Schematron(_Validator): + """Schematron(self, etree=None, file=None) + A Schematron validator. + + Pass a root Element or an ElementTree to turn it into a validator. + Alternatively, pass a filename as keyword argument 'file' to parse from + the file system. + + Schematron is a less well known, but very powerful schema language. The main + idea is to use the capabilities of XPath to put restrictions on the structure + and the content of XML documents. Here is a simple example:: + + >>> schematron = Schematron(XML(''' + ... + ... + ... + ... Attribute + ... is forbidden + ... + ... + ... + ... + ... ''')) + + >>> xml = XML(''' + ... + ... + ... + ... + ... ''') + + >>> schematron.validate(xml) + 0 + + >>> xml = XML(''' + ... + ... + ... + ... + ... ''') + + >>> schematron.validate(xml) + 1 + + Schematron was added to libxml2 in version 2.6.21. Before version 2.6.32, + however, Schematron lacked support for error reporting other than to stderr. + This version is therefore required to retrieve validation warnings and + errors in lxml. + """ + cdef schematron.xmlSchematron* _c_schema + cdef xmlDoc* _c_schema_doc + + def __init__(self, etree=None, *, file=None): + cdef _Document doc + cdef _Element root_node + cdef xmlNode* c_node + cdef char* c_href + cdef schematron.xmlSchematronParserCtxt* parser_ctxt = NULL + _Validator.__init__(self) + if not config.ENABLE_SCHEMATRON: + raise SchematronError, \ + "lxml.etree was compiled without Schematron support." + + import warnings + warnings.warn( + "The (non-ISO) Schematron feature is deprecated and will be removed from libxml2 and lxml. " + "Use 'lxml.isoschematron' instead.", + DeprecationWarning, + ) + + if etree is not None: + doc = _documentOrRaise(etree) + root_node = _rootNodeOrRaise(etree) + self._c_schema_doc = _copyDocRoot(doc._c_doc, root_node._c_node) + parser_ctxt = schematron.xmlSchematronNewDocParserCtxt(self._c_schema_doc) + elif file is not None: + filename = _getFilenameForFile(file) + if filename is None: + # XXX assume a string object + filename = file + filename = _encodeFilename(filename) + with self._error_log: + orig_loader = _register_document_loader() + parser_ctxt = schematron.xmlSchematronNewParserCtxt(_cstr(filename)) + _reset_document_loader(orig_loader) + else: + raise SchematronParseError, "No tree or file given" + + if parser_ctxt is NULL: + if self._c_schema_doc is not NULL: + tree.xmlFreeDoc(self._c_schema_doc) + self._c_schema_doc = NULL + raise MemoryError() + + try: + with self._error_log: + orig_loader = _register_document_loader() + self._c_schema = schematron.xmlSchematronParse(parser_ctxt) + _reset_document_loader(orig_loader) + finally: + schematron.xmlSchematronFreeParserCtxt(parser_ctxt) + + if self._c_schema is NULL: + raise SchematronParseError( + "Document is not a valid Schematron schema", + self._error_log) + + def __dealloc__(self): + schematron.xmlSchematronFree(self._c_schema) + if self._c_schema_doc is not NULL: + tree.xmlFreeDoc(self._c_schema_doc) + + def __call__(self, etree): + """__call__(self, etree) + + Validate doc using Schematron. + + Returns true if document is valid, false if not.""" + cdef _Document doc + cdef _Element root_node + cdef xmlDoc* c_doc + cdef schematron.xmlSchematronValidCtxt* valid_ctxt + cdef int ret + + assert self._c_schema is not NULL, "Schematron instance not initialised" + doc = _documentOrRaise(etree) + root_node = _rootNodeOrRaise(etree) + + valid_ctxt = schematron.xmlSchematronNewValidCtxt( + self._c_schema, schematron.XML_SCHEMATRON_OUT_ERROR) + if valid_ctxt is NULL: + raise MemoryError() + + try: + self._error_log.clear() + # Need a cast here because older libxml2 releases do not use 'const' in the functype. + schematron.xmlSchematronSetValidStructuredErrors( + valid_ctxt, _receiveError, self._error_log) + c_doc = _fakeRootDoc(doc._c_doc, root_node._c_node) + with nogil: + ret = schematron.xmlSchematronValidateDoc(valid_ctxt, c_doc) + _destroyFakeDoc(doc._c_doc, c_doc) + finally: + schematron.xmlSchematronFreeValidCtxt(valid_ctxt) + + if ret == -1: + raise SchematronValidateError( + "Internal error in Schematron validation", + self._error_log) + if ret == 0: + return True + else: + return False diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/serializer.pxi b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/serializer.pxi new file mode 100644 index 0000000000000000000000000000000000000000..5266bdf2bdc71a0bbdbe0c8702dbc43f5684ee28 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/serializer.pxi @@ -0,0 +1,1849 @@ +# XML serialization and output functions + +cdef object GzipFile +from gzip import GzipFile + + +cdef class SerialisationError(LxmlError): + """A libxml2 error that occurred during serialisation. + """ + + +cdef enum _OutputMethods: + OUTPUT_METHOD_XML + OUTPUT_METHOD_HTML + OUTPUT_METHOD_TEXT + + +cdef int _findOutputMethod(method) except -1: + if method is None: + return OUTPUT_METHOD_XML + method = method.lower() + if method == "xml": + return OUTPUT_METHOD_XML + if method == "html": + return OUTPUT_METHOD_HTML + if method == "text": + return OUTPUT_METHOD_TEXT + raise ValueError(f"unknown output method {method!r}") + + +cdef _textToString(xmlNode* c_node, encoding, bint with_tail): + cdef bint needs_conversion + cdef const_xmlChar* c_text + cdef xmlNode* c_text_node + cdef tree.xmlBuffer* c_buffer + cdef int error_result + + c_buffer = tree.xmlBufferCreate() + if c_buffer is NULL: + raise MemoryError() + + with nogil: + error_result = tree.xmlNodeBufGetContent(c_buffer, c_node) + if with_tail: + c_text_node = _textNodeOrSkip(c_node.next) + while c_text_node is not NULL: + tree.xmlBufferWriteChar(c_buffer, c_text_node.content) + c_text_node = _textNodeOrSkip(c_text_node.next) + c_text = tree.xmlBufferContent(c_buffer) + + if error_result < 0 or c_text is NULL: + tree.xmlBufferFree(c_buffer) + raise SerialisationError, "Error during serialisation (out of memory?)" + + try: + needs_conversion = 0 + if encoding is unicode: + needs_conversion = 1 + elif encoding is not None: + # Python prefers lower case encoding names + encoding = encoding.lower() + if encoding not in ('utf8', 'utf-8'): + if encoding == 'ascii': + if isutf8l(c_text, tree.xmlBufferLength(c_buffer)): + # will raise a decode error below + needs_conversion = 1 + else: + needs_conversion = 1 + + if needs_conversion: + text = (c_text)[:tree.xmlBufferLength(c_buffer)].decode('utf8') + if encoding is not unicode: + encoding = _utf8(encoding) + text = python.PyUnicode_AsEncodedString( + text, encoding, 'strict') + else: + text = (c_text)[:tree.xmlBufferLength(c_buffer)] + finally: + tree.xmlBufferFree(c_buffer) + return text + + +cdef _tostring(_Element element, encoding, doctype, method, + bint write_xml_declaration, bint write_complete_document, + bint pretty_print, bint with_tail, int standalone): + """Serialize an element to an encoded string representation of its XML + tree. + """ + cdef tree.xmlOutputBuffer* c_buffer + cdef tree.xmlBuf* c_result_buffer + cdef tree.xmlCharEncodingHandler* enchandler + cdef const_char* c_enc + cdef const_xmlChar* c_version + cdef const_xmlChar* c_doctype + cdef int c_method + cdef int error_result + if element is None: + return None + _assertValidNode(element) + c_method = _findOutputMethod(method) + if c_method == OUTPUT_METHOD_TEXT: + return _textToString(element._c_node, encoding, with_tail) + if encoding is None or encoding is unicode: + c_enc = NULL + else: + encoding = _utf8(encoding) + c_enc = _cstr(encoding) + if doctype is None: + c_doctype = NULL + else: + doctype = _utf8(doctype) + c_doctype = _xcstr(doctype) + # it is necessary to *and* find the encoding handler *and* use + # encoding during output + enchandler = tree.xmlFindCharEncodingHandler(c_enc) + if enchandler is NULL and c_enc is not NULL: + if encoding is not None: + encoding = encoding.decode('UTF-8') + raise LookupError, f"unknown encoding: '{encoding}'" + c_buffer = tree.xmlAllocOutputBuffer(enchandler) + if c_buffer is NULL: + tree.xmlCharEncCloseFunc(enchandler) + raise MemoryError() + + with nogil: + _writeNodeToBuffer(c_buffer, element._c_node, c_enc, c_doctype, c_method, + write_xml_declaration, write_complete_document, + pretty_print, with_tail, standalone) + tree.xmlOutputBufferFlush(c_buffer) + if c_buffer.conv is not NULL: + c_result_buffer = c_buffer.conv + else: + c_result_buffer = c_buffer.buffer + + error_result = c_buffer.error + if error_result != xmlerror.XML_ERR_OK: + tree.xmlOutputBufferClose(c_buffer) + _raiseSerialisationError(error_result) + + try: + if encoding is unicode: + result = (tree.xmlBufContent( + c_result_buffer))[:tree.xmlBufUse(c_result_buffer)].decode('UTF-8') + else: + result = (tree.xmlBufContent( + c_result_buffer))[:tree.xmlBufUse(c_result_buffer)] + finally: + error_result = tree.xmlOutputBufferClose(c_buffer) + if error_result == -1: + _raiseSerialisationError(error_result) + return result + +cdef bytes _tostringC14N(element_or_tree, bint exclusive, bint with_comments, inclusive_ns_prefixes): + cdef xmlDoc* c_doc + cdef xmlChar* c_buffer = NULL + cdef int byte_count = -1 + cdef bytes result + cdef _Document doc + cdef _Element element + cdef xmlChar **c_inclusive_ns_prefixes + + if isinstance(element_or_tree, _Element): + _assertValidNode(<_Element>element_or_tree) + doc = (<_Element>element_or_tree)._doc + c_doc = _plainFakeRootDoc(doc._c_doc, (<_Element>element_or_tree)._c_node, 0) + else: + doc = _documentOrRaise(element_or_tree) + _assertValidDoc(doc) + c_doc = doc._c_doc + + c_inclusive_ns_prefixes = _convert_ns_prefixes(c_doc.dict, inclusive_ns_prefixes) if inclusive_ns_prefixes else NULL + try: + with nogil: + byte_count = c14n.xmlC14NDocDumpMemory( + c_doc, NULL, exclusive, c_inclusive_ns_prefixes, with_comments, &c_buffer) + + finally: + _destroyFakeDoc(doc._c_doc, c_doc) + if c_inclusive_ns_prefixes is not NULL: + python.lxml_free(c_inclusive_ns_prefixes) + + if byte_count < 0 or c_buffer is NULL: + if c_buffer is not NULL: + tree.xmlFree(c_buffer) + raise C14NError, "C14N failed" + try: + result = c_buffer[:byte_count] + finally: + tree.xmlFree(c_buffer) + return result + +cdef _raiseSerialisationError(int error_result): + if error_result == xmlerror.XML_ERR_NO_MEMORY: + raise MemoryError() + message = ErrorTypes._getName(error_result) + if message is None: + message = f"unknown error {error_result}" + raise SerialisationError, message + +############################################################ +# low-level serialisation functions + +cdef void _writeDoctype(tree.xmlOutputBuffer* c_buffer, + const_xmlChar* c_doctype) noexcept nogil: + tree.xmlOutputBufferWrite(c_buffer, tree.xmlStrlen(c_doctype), + c_doctype) + tree.xmlOutputBufferWriteString(c_buffer, "\n") + +cdef void _writeNodeToBuffer(tree.xmlOutputBuffer* c_buffer, + xmlNode* c_node, const_char* encoding, const_xmlChar* c_doctype, + int c_method, bint write_xml_declaration, + bint write_complete_document, + bint pretty_print, bint with_tail, + int standalone) noexcept nogil: + cdef xmlNode* c_nsdecl_node + cdef xmlDoc* c_doc = c_node.doc + if write_xml_declaration and c_method == OUTPUT_METHOD_XML: + _writeDeclarationToBuffer(c_buffer, c_doc.version, encoding, standalone) + + # comments/processing instructions before doctype declaration + if write_complete_document and not c_buffer.error and c_doc.intSubset: + _writePrevSiblings(c_buffer, c_doc.intSubset, encoding, pretty_print) + + if c_doctype: + _writeDoctype(c_buffer, c_doctype) + # write internal DTD subset, preceding PIs/comments, etc. + if write_complete_document and not c_buffer.error: + if c_doctype is NULL: + _writeDtdToBuffer(c_buffer, c_doc, c_node.name, c_method, encoding) + _writePrevSiblings(c_buffer, c_node, encoding, pretty_print) + + c_nsdecl_node = c_node + if not c_node.parent or c_node.parent.type != tree.XML_DOCUMENT_NODE: + # copy the node and add namespaces from parents + # this is required to make libxml write them + c_nsdecl_node = tree.xmlCopyNode(c_node, 2) + if not c_nsdecl_node: + c_buffer.error = xmlerror.XML_ERR_NO_MEMORY + return + _copyParentNamespaces(c_node, c_nsdecl_node) + + c_nsdecl_node.parent = c_node.parent + c_nsdecl_node.children = c_node.children + c_nsdecl_node.last = c_node.last + + # write node + if c_method == OUTPUT_METHOD_HTML: + tree.htmlNodeDumpFormatOutput( + c_buffer, c_doc, c_nsdecl_node, encoding, pretty_print) + else: + tree.xmlNodeDumpOutput( + c_buffer, c_doc, c_nsdecl_node, 0, pretty_print, encoding) + + if c_nsdecl_node is not c_node: + # clean up + c_nsdecl_node.children = c_nsdecl_node.last = NULL + tree.xmlFreeNode(c_nsdecl_node) + + if c_buffer.error: + return + + # write tail, trailing comments, etc. + if with_tail: + _writeTail(c_buffer, c_node, encoding, c_method, pretty_print) + if write_complete_document: + _writeNextSiblings(c_buffer, c_node, encoding, pretty_print) + if pretty_print: + tree.xmlOutputBufferWrite(c_buffer, 1, "\n") + +cdef void _writeDeclarationToBuffer(tree.xmlOutputBuffer* c_buffer, + const_xmlChar* version, const_char* encoding, + int standalone) noexcept nogil: + if version is NULL: + version = "1.0" + tree.xmlOutputBufferWrite(c_buffer, 15, "version) + tree.xmlOutputBufferWrite(c_buffer, 12, "' encoding='") + tree.xmlOutputBufferWriteString(c_buffer, encoding) + if standalone == 0: + tree.xmlOutputBufferWrite(c_buffer, 20, "' standalone='no'?>\n") + elif standalone == 1: + tree.xmlOutputBufferWrite(c_buffer, 21, "' standalone='yes'?>\n") + else: + tree.xmlOutputBufferWrite(c_buffer, 4, "'?>\n") + +cdef void _writeDtdToBuffer(tree.xmlOutputBuffer* c_buffer, + xmlDoc* c_doc, const_xmlChar* c_root_name, + int c_method, const_char* encoding) noexcept nogil: + cdef tree.xmlDtd* c_dtd + cdef xmlNode* c_node + cdef char* quotechar + c_dtd = c_doc.intSubset + if not c_dtd or not c_dtd.name: + return + + # Name in document type declaration must match the root element tag. + # For XML, case sensitive match, for HTML insensitive. + if c_method == OUTPUT_METHOD_HTML: + if tree.xmlStrcasecmp(c_root_name, c_dtd.name) != 0: + return + else: + if tree.xmlStrcmp(c_root_name, c_dtd.name) != 0: + return + + tree.xmlOutputBufferWrite(c_buffer, 10, "c_dtd.name) + + cdef const_xmlChar* public_id = c_dtd.ExternalID + cdef const_xmlChar* sys_url = c_dtd.SystemID + if public_id and public_id[0] == b'\0': + public_id = NULL + if sys_url and sys_url[0] == b'\0': + sys_url = NULL + + if public_id: + tree.xmlOutputBufferWrite(c_buffer, 9, ' PUBLIC "') + tree.xmlOutputBufferWriteString(c_buffer, public_id) + if sys_url: + tree.xmlOutputBufferWrite(c_buffer, 2, '" ') + else: + tree.xmlOutputBufferWrite(c_buffer, 1, '"') + elif sys_url: + tree.xmlOutputBufferWrite(c_buffer, 8, ' SYSTEM ') + + if sys_url: + if tree.xmlStrchr(sys_url, b'"'): + quotechar = '\'' + else: + quotechar = '"' + tree.xmlOutputBufferWrite(c_buffer, 1, quotechar) + tree.xmlOutputBufferWriteString(c_buffer, sys_url) + tree.xmlOutputBufferWrite(c_buffer, 1, quotechar) + + if (not c_dtd.entities and not c_dtd.elements and + not c_dtd.attributes and not c_dtd.notations and + not c_dtd.pentities): + tree.xmlOutputBufferWrite(c_buffer, 2, '>\n') + return + + tree.xmlOutputBufferWrite(c_buffer, 3, ' [\n') + if c_dtd.notations and not c_buffer.error: + c_buf = tree.xmlBufferCreate() + if not c_buf: + c_buffer.error = xmlerror.XML_ERR_NO_MEMORY + return + tree.xmlDumpNotationTable(c_buf, c_dtd.notations) + tree.xmlOutputBufferWrite( + c_buffer, tree.xmlBufferLength(c_buf), + tree.xmlBufferContent(c_buf)) + tree.xmlBufferFree(c_buf) + c_node = c_dtd.children + while c_node and not c_buffer.error: + tree.xmlNodeDumpOutput(c_buffer, c_node.doc, c_node, 0, 0, encoding) + c_node = c_node.next + tree.xmlOutputBufferWrite(c_buffer, 3, "]>\n") + +cdef void _writeTail(tree.xmlOutputBuffer* c_buffer, xmlNode* c_node, + const_char* encoding, int c_method, bint pretty_print) noexcept nogil: + "Write the element tail." + c_node = c_node.next + while c_node and not c_buffer.error and c_node.type in ( + tree.XML_TEXT_NODE, tree.XML_CDATA_SECTION_NODE): + if c_method == OUTPUT_METHOD_HTML: + tree.htmlNodeDumpFormatOutput( + c_buffer, c_node.doc, c_node, encoding, pretty_print) + else: + tree.xmlNodeDumpOutput( + c_buffer, c_node.doc, c_node, 0, pretty_print, encoding) + c_node = c_node.next + +cdef void _writePrevSiblings(tree.xmlOutputBuffer* c_buffer, xmlNode* c_node, + const_char* encoding, bint pretty_print) noexcept nogil: + cdef xmlNode* c_sibling + if c_node.parent and _isElement(c_node.parent): + return + # we are at a root node, so add PI and comment siblings + c_sibling = c_node + while c_sibling.prev and \ + (c_sibling.prev.type == tree.XML_PI_NODE or + c_sibling.prev.type == tree.XML_COMMENT_NODE): + c_sibling = c_sibling.prev + while c_sibling is not c_node and not c_buffer.error: + tree.xmlNodeDumpOutput(c_buffer, c_node.doc, c_sibling, 0, + pretty_print, encoding) + if pretty_print: + tree.xmlOutputBufferWriteString(c_buffer, "\n") + c_sibling = c_sibling.next + +cdef void _writeNextSiblings(tree.xmlOutputBuffer* c_buffer, xmlNode* c_node, + const_char* encoding, bint pretty_print) noexcept nogil: + cdef xmlNode* c_sibling + if c_node.parent and _isElement(c_node.parent): + return + # we are at a root node, so add PI and comment siblings + c_sibling = c_node.next + while not c_buffer.error and c_sibling and \ + (c_sibling.type == tree.XML_PI_NODE or + c_sibling.type == tree.XML_COMMENT_NODE): + if pretty_print: + tree.xmlOutputBufferWriteString(c_buffer, "\n") + tree.xmlNodeDumpOutput(c_buffer, c_node.doc, c_sibling, 0, + pretty_print, encoding) + c_sibling = c_sibling.next + + +# copied and adapted from libxml2 (xmlBufAttrSerializeTxtContent()) +cdef _write_attr_string(tree.xmlOutputBuffer* buf, const char *string): + cdef const char *base + cdef const char *cur + + if string == NULL: + return + + base = cur = string + while cur[0] != 0: + if cur[0] == b'\n': + if base != cur: + tree.xmlOutputBufferWrite(buf, cur - base, base) + + tree.xmlOutputBufferWrite(buf, 5, " ") + cur += 1 + base = cur + + elif cur[0] == b'\r': + if base != cur: + tree.xmlOutputBufferWrite(buf, cur - base, base) + + tree.xmlOutputBufferWrite(buf, 5, " ") + cur += 1 + base = cur + + elif cur[0] == b'\t': + if base != cur: + tree.xmlOutputBufferWrite(buf, cur - base, base) + + tree.xmlOutputBufferWrite(buf, 4, " ") + cur += 1 + base = cur + + elif cur[0] == b'"': + if base != cur: + tree.xmlOutputBufferWrite(buf, cur - base, base) + + tree.xmlOutputBufferWrite(buf, 6, """) + cur += 1 + base = cur + + elif cur[0] == b'<': + if base != cur: + tree.xmlOutputBufferWrite(buf, cur - base, base) + + tree.xmlOutputBufferWrite(buf, 4, "<") + cur += 1 + base = cur + + elif cur[0] == b'>': + if base != cur: + tree.xmlOutputBufferWrite(buf, cur - base, base) + + tree.xmlOutputBufferWrite(buf, 4, ">") + cur += 1 + base = cur + elif cur[0] == b'&': + if base != cur: + tree.xmlOutputBufferWrite(buf, cur - base, base) + + tree.xmlOutputBufferWrite(buf, 5, "&") + cur += 1 + base = cur + + else: + # Leave further encoding and escaping to the buffer encoder. + cur += 1 + + if base != cur: + tree.xmlOutputBufferWrite(buf, cur - base, base) + + +cdef void _write_cdata_section(tree.xmlOutputBuffer* buf, const char* c_data, const char* c_end): + tree.xmlOutputBufferWrite(buf, 9, " limits.INT_MAX: + tree.xmlOutputBufferWrite(buf, limits.INT_MAX, c_data) + c_data += limits.INT_MAX + tree.xmlOutputBufferWrite(buf, c_end - c_data, c_data) + tree.xmlOutputBufferWrite(buf, 3, "]]>") + + +cdef _write_cdata_string(tree.xmlOutputBuffer* buf, bytes bstring): + cdef const char* c_data = bstring + cdef const char* c_end = c_data + len(bstring) + cdef const char* c_pos = c_data + cdef bint nothing_written = True + + while True: + c_pos = cstring_h.memchr(c_pos, b']', c_end - c_pos) + if not c_pos: + break + c_pos += 1 + next_char = c_pos[0] + c_pos += 1 + if next_char != b']': + continue + # Found ']]', c_pos points to next character. + while c_pos[0] == b']': + c_pos += 1 + if c_pos[0] != b'>': + if c_pos == c_end: + break + # c_pos[0] is neither ']' nor '>', continue with next character. + c_pos += 1 + continue + + # Write section up to ']]' and start next block at trailing '>'. + _write_cdata_section(buf, c_data, c_pos) + nothing_written = False + c_data = c_pos + c_pos += 1 + + if nothing_written or c_data < c_end: + _write_cdata_section(buf, c_data, c_end) + + +############################################################ +# output to file-like objects + +cdef object io_open +from io import open as io_open + +cdef object gzip +import gzip + +cdef object getwriter +from codecs import getwriter +cdef object utf8_writer = getwriter('utf8') + +cdef object contextmanager +from contextlib import contextmanager + +cdef object _open_utf8_file + +@contextmanager +def _open_utf8_file(file, compression=0): + file = _getFSPathOrObject(file) + if _isString(file): + if compression: + with gzip.GzipFile(file, mode='wb', compresslevel=compression) as zf: + yield utf8_writer(zf) + else: + with io_open(file, 'w', encoding='utf8') as f: + yield f + else: + if compression: + with gzip.GzipFile(fileobj=file, mode='wb', compresslevel=compression) as zf: + yield utf8_writer(zf) + else: + yield utf8_writer(file) + + +@cython.final +@cython.internal +cdef class _FilelikeWriter: + cdef object _filelike + cdef object _close_filelike + cdef _ExceptionContext _exc_context + cdef _ErrorLog error_log + + def __cinit__(self, filelike, exc_context=None, compression=None, close=False): + if compression is not None and compression > 0: + filelike = GzipFile( + fileobj=filelike, mode='wb', compresslevel=compression) + self._close_filelike = filelike.close + elif close: + self._close_filelike = filelike.close + self._filelike = filelike + if exc_context is None: + self._exc_context = _ExceptionContext() + else: + self._exc_context = exc_context + self.error_log = _ErrorLog() + + cdef tree.xmlOutputBuffer* _createOutputBuffer( + self, tree.xmlCharEncodingHandler* enchandler) except NULL: + cdef tree.xmlOutputBuffer* c_buffer + c_buffer = tree.xmlOutputBufferCreateIO( + _writeFilelikeWriter, _closeFilelikeWriter, + self, enchandler) + if c_buffer is NULL: + raise IOError, "Could not create I/O writer context." + return c_buffer + + cdef int write(self, char* c_buffer, int size) noexcept: + try: + if self._filelike is None: + raise IOError, "File is already closed" + py_buffer = c_buffer[:size] + self._filelike.write(py_buffer) + except: + size = -1 + self._exc_context._store_raised() + finally: + return size # and swallow any further exceptions + + cdef int close(self) noexcept: + retval = 0 + try: + if self._close_filelike is not None: + self._close_filelike() + # we should not close the file here as we didn't open it + self._filelike = None + except: + retval = -1 + self._exc_context._store_raised() + finally: + return retval # and swallow any further exceptions + +cdef int _writeFilelikeWriter(void* ctxt, char* c_buffer, int length) noexcept: + return (<_FilelikeWriter>ctxt).write(c_buffer, length) + +cdef int _closeFilelikeWriter(void* ctxt) noexcept: + return (<_FilelikeWriter>ctxt).close() + +cdef _tofilelike(f, _Element element, encoding, doctype, method, + bint write_xml_declaration, bint write_doctype, + bint pretty_print, bint with_tail, int standalone, + int compression): + cdef _FilelikeWriter writer = None + cdef tree.xmlOutputBuffer* c_buffer + cdef tree.xmlCharEncodingHandler* enchandler + cdef const_char* c_enc + cdef const_xmlChar* c_doctype + cdef int error_result + + c_method = _findOutputMethod(method) + if c_method == OUTPUT_METHOD_TEXT: + data = _textToString(element._c_node, encoding, with_tail) + if compression: + bytes_out = BytesIO() + with GzipFile(fileobj=bytes_out, mode='wb', compresslevel=compression) as gzip_file: + gzip_file.write(data) + data = bytes_out.getvalue() + f = _getFSPathOrObject(f) + if _isString(f): + filename8 = _encodeFilename(f) + with open(filename8, 'wb') as f: + f.write(data) + else: + f.write(data) + return + + if encoding is None: + c_enc = NULL + else: + encoding = _utf8(encoding) + c_enc = _cstr(encoding) + if doctype is None: + c_doctype = NULL + else: + doctype = _utf8(doctype) + c_doctype = _xcstr(doctype) + + writer = _create_output_buffer(f, c_enc, compression, &c_buffer, close=False) + if writer is None: + with nogil: + error_result = _serialise_node( + c_buffer, c_doctype, c_enc, element._c_node, c_method, + write_xml_declaration, write_doctype, pretty_print, with_tail, standalone) + else: + error_result = _serialise_node( + c_buffer, c_doctype, c_enc, element._c_node, c_method, + write_xml_declaration, write_doctype, pretty_print, with_tail, standalone) + + if writer is not None: + writer._exc_context._raise_if_stored() + if error_result != xmlerror.XML_ERR_OK: + _raiseSerialisationError(error_result) + + +cdef int _serialise_node(tree.xmlOutputBuffer* c_buffer, const_xmlChar* c_doctype, + const_char* c_enc, xmlNode* c_node, int c_method, + bint write_xml_declaration, bint write_doctype, bint pretty_print, + bint with_tail, int standalone) noexcept nogil: + _writeNodeToBuffer( + c_buffer, c_node, c_enc, c_doctype, c_method, + write_xml_declaration, write_doctype, pretty_print, with_tail, standalone) + error_result = c_buffer.error + if error_result == xmlerror.XML_ERR_OK: + error_result = tree.xmlOutputBufferClose(c_buffer) + if error_result != -1: + error_result = xmlerror.XML_ERR_OK + else: + tree.xmlOutputBufferClose(c_buffer) + return error_result + + +cdef _FilelikeWriter _create_output_buffer( + f, const_char* c_enc, int c_compression, + tree.xmlOutputBuffer** c_buffer_ret, bint close): + cdef tree.xmlOutputBuffer* c_buffer + cdef _FilelikeWriter writer + cdef bytes filename8 + enchandler = tree.xmlFindCharEncodingHandler(c_enc) + if enchandler is NULL: + raise LookupError( + f"unknown encoding: '{c_enc.decode('UTF-8') if c_enc is not NULL else u''}'") + try: + f = _getFSPathOrObject(f) + + if c_compression and not HAS_ZLIB_COMPRESSION and _isString(f): + # Let "_FilelikeWriter" fall back to Python's GzipFile. + f = open(f, mode="wb") + close = True + + if _isString(f): + filename8 = _encodeFilename(f) + if b'%' in filename8 and ( + # Exclude absolute Windows paths and file:// URLs. + _isFilePath(filename8) not in (NO_FILE_PATH, ABS_WIN_FILE_PATH) + or filename8[:7].lower() == b'file://'): + # A file path (not a URL) containing the '%' URL escape character. + # libxml2 uses URL-unescaping on these, so escape the path before passing it in. + filename8 = filename8.replace(b'%', b'%25') + c_buffer = tree.xmlOutputBufferCreateFilename( + _cstr(filename8), enchandler, c_compression) + if c_buffer is NULL: + python.PyErr_SetFromErrno(IOError) # raises IOError + writer = None + elif hasattr(f, 'write'): + writer = _FilelikeWriter(f, compression=c_compression, close=close) + c_buffer = writer._createOutputBuffer(enchandler) + else: + raise TypeError( + f"File or filename expected, got '{python._fqtypename(f).decode('UTF-8')}'") + except: + tree.xmlCharEncCloseFunc(enchandler) + raise + c_buffer_ret[0] = c_buffer + return writer + +cdef xmlChar **_convert_ns_prefixes(tree.xmlDict* c_dict, ns_prefixes) except NULL: + cdef size_t i, num_ns_prefixes = len(ns_prefixes) + # Need to allocate one extra memory block to handle last NULL entry + c_ns_prefixes = python.lxml_malloc(num_ns_prefixes + 1, sizeof(xmlChar*)) + if not c_ns_prefixes: + raise MemoryError() + i = 0 + try: + for prefix in ns_prefixes: + prefix_utf = _utf8(prefix) + c_prefix_len = len(prefix_utf) + if c_prefix_len > limits.INT_MAX: + raise ValueError("Prefix too long") + c_prefix = tree.xmlDictExists(c_dict, _xcstr(prefix_utf), c_prefix_len) + if c_prefix: + # unknown prefixes do not need to get serialised + c_ns_prefixes[i] = c_prefix + i += 1 + except: + python.lxml_free(c_ns_prefixes) + raise + + c_ns_prefixes[i] = NULL # append end marker + return c_ns_prefixes + +cdef _tofilelikeC14N(f, _Element element, bint exclusive, bint with_comments, + int compression, inclusive_ns_prefixes): + cdef _FilelikeWriter writer = None + cdef tree.xmlOutputBuffer* c_buffer + cdef xmlChar **c_inclusive_ns_prefixes = NULL + cdef char* c_filename + cdef xmlDoc* c_base_doc + cdef xmlDoc* c_doc + cdef int bytes_count, error = 0 + + c_base_doc = element._c_node.doc + c_doc = _fakeRootDoc(c_base_doc, element._c_node) + try: + c_inclusive_ns_prefixes = ( + _convert_ns_prefixes(c_doc.dict, inclusive_ns_prefixes) + if inclusive_ns_prefixes else NULL) + + f = _getFSPathOrObject(f) + + close = False + if compression and not HAS_ZLIB_COMPRESSION and _isString(f): + # Let "_FilelikeWriter" fall back to Python's GzipFile. + f = open(f, mode="wb") + close = True + + if _isString(f): + filename8 = _encodeFilename(f) + c_filename = _cstr(filename8) + with nogil: + error = c14n.xmlC14NDocSave( + c_doc, NULL, exclusive, c_inclusive_ns_prefixes, + with_comments, c_filename, compression) + elif hasattr(f, 'write'): + writer = _FilelikeWriter(f, compression=compression, close=close) + c_buffer = writer._createOutputBuffer(NULL) + try: + with writer.error_log: + bytes_count = c14n.xmlC14NDocSaveTo( + c_doc, NULL, exclusive, c_inclusive_ns_prefixes, + with_comments, c_buffer) + finally: + error = tree.xmlOutputBufferClose(c_buffer) + if bytes_count < 0: + error = bytes_count + elif error != -1: + error = xmlerror.XML_ERR_OK + else: + raise TypeError(f"File or filename expected, got '{python._fqtypename(f).decode('UTF-8')}'") + finally: + _destroyFakeDoc(c_base_doc, c_doc) + if c_inclusive_ns_prefixes is not NULL: + python.lxml_free(c_inclusive_ns_prefixes) + + if writer is not None: + writer._exc_context._raise_if_stored() + + if error < 0: + message = "C14N failed" + if writer is not None: + errors = writer.error_log + if len(errors): + message = errors[0].message + raise C14NError(message) + + +# C14N 2.0 + +def canonicalize(xml_data=None, *, out=None, from_file=None, **options): + """Convert XML to its C14N 2.0 serialised form. + + If *out* is provided, it must be a file or file-like object that receives + the serialised canonical XML output (text, not bytes) through its ``.write()`` + method. To write to a file, open it in text mode with encoding "utf-8". + If *out* is not provided, this function returns the output as text string. + + Either *xml_data* (an XML string, tree or Element) or *file* + (a file path or file-like object) must be provided as input. + + The configuration options are the same as for the ``C14NWriterTarget``. + """ + if xml_data is None and from_file is None: + raise ValueError("Either 'xml_data' or 'from_file' must be provided as input") + + sio = None + if out is None: + sio = out = StringIO() + + target = C14NWriterTarget(out.write, **options) + + if xml_data is not None and not isinstance(xml_data, basestring): + _tree_to_target(xml_data, target) + return sio.getvalue() if sio is not None else None + + cdef _FeedParser parser = XMLParser( + target=target, + attribute_defaults=True, + collect_ids=False, + ) + + if xml_data is not None: + parser.feed(xml_data) + parser.close() + elif from_file is not None: + try: + _parseDocument(from_file, parser, base_url=None) + except _TargetParserResult: + pass + + return sio.getvalue() if sio is not None else None + + +cdef _tree_to_target(element, target): + for event, elem in iterwalk(element, events=('start', 'end', 'start-ns', 'comment', 'pi')): + text = None + if event == 'start': + target.start(elem.tag, elem.attrib) + text = elem.text + elif event == 'end': + target.end(elem.tag) + text = elem.tail + elif event == 'start-ns': + target.start_ns(*elem) + continue + elif event == 'comment': + target.comment(elem.text) + text = elem.tail + elif event == 'pi': + target.pi(elem.target, elem.text) + text = elem.tail + if text: + target.data(text) + return target.close() + + +cdef object _looks_like_prefix_name = re.compile(r'^\w+:\w+$', re.UNICODE).match + + +cdef class C14NWriterTarget: + """ + Canonicalization writer target for the XMLParser. + + Serialises parse events to XML C14N 2.0. + + Configuration options: + + - *with_comments*: set to true to include comments + - *strip_text*: set to true to strip whitespace before and after text content + - *rewrite_prefixes*: set to true to replace namespace prefixes by "n{number}" + - *qname_aware_tags*: a set of qname aware tag names in which prefixes + should be replaced in text content + - *qname_aware_attrs*: a set of qname aware attribute names in which prefixes + should be replaced in text content + - *exclude_attrs*: a set of attribute names that should not be serialised + - *exclude_tags*: a set of tag names that should not be serialised + """ + cdef object _write + cdef list _data + cdef set _qname_aware_tags + cdef object _find_qname_aware_attrs + cdef list _declared_ns_stack + cdef list _ns_stack + cdef dict _prefix_map + cdef list _preserve_space + cdef tuple _pending_start + cdef set _exclude_tags + cdef set _exclude_attrs + cdef Py_ssize_t _ignored_depth + cdef bint _with_comments + cdef bint _strip_text + cdef bint _rewrite_prefixes + cdef bint _root_seen + cdef bint _root_done + + def __init__(self, write, *, + with_comments=False, strip_text=False, rewrite_prefixes=False, + qname_aware_tags=None, qname_aware_attrs=None, + exclude_attrs=None, exclude_tags=None): + self._write = write + self._data = [] + self._with_comments = with_comments + self._strip_text = strip_text + self._exclude_attrs = set(exclude_attrs) if exclude_attrs else None + self._exclude_tags = set(exclude_tags) if exclude_tags else None + + self._rewrite_prefixes = rewrite_prefixes + if qname_aware_tags: + self._qname_aware_tags = set(qname_aware_tags) + else: + self._qname_aware_tags = None + if qname_aware_attrs: + self._find_qname_aware_attrs = set(qname_aware_attrs).intersection + else: + self._find_qname_aware_attrs = None + + # Stack with globally and newly declared namespaces as (uri, prefix) pairs. + self._declared_ns_stack = [[ + ("http://www.w3.org/XML/1998/namespace", "xml"), + ]] + # Stack with user declared namespace prefixes as (uri, prefix) pairs. + self._ns_stack = [] + if not rewrite_prefixes: + self._ns_stack.append(_DEFAULT_NAMESPACE_PREFIXES_ITEMS) + self._ns_stack.append([]) + self._prefix_map = {} + self._preserve_space = [False] + self._pending_start = None + self._ignored_depth = 0 + self._root_seen = False + self._root_done = False + + def _iter_namespaces(self, ns_stack): + for namespaces in reversed(ns_stack): + if namespaces: # almost no element declares new namespaces + yield from namespaces + + cdef _resolve_prefix_name(self, prefixed_name): + prefix, name = prefixed_name.split(':', 1) + for uri, p in self._iter_namespaces(self._ns_stack): + if p == prefix: + return f'{{{uri}}}{name}' + raise ValueError(f'Prefix {prefix} of QName "{prefixed_name}" is not declared in scope') + + cdef _qname(self, qname, uri=None): + if uri is None: + uri, tag = qname[1:].rsplit('}', 1) if qname[:1] == '{' else ('', qname) + else: + tag = qname + + prefixes_seen = set() + for u, prefix in self._iter_namespaces(self._declared_ns_stack): + if u == uri and prefix not in prefixes_seen: + return f'{prefix}:{tag}' if prefix else tag, tag, uri + prefixes_seen.add(prefix) + + # Not declared yet => add new declaration. + if self._rewrite_prefixes: + if uri in self._prefix_map: + prefix = self._prefix_map[uri] + else: + prefix = self._prefix_map[uri] = f'n{len(self._prefix_map)}' + self._declared_ns_stack[-1].append((uri, prefix)) + return f'{prefix}:{tag}', tag, uri + + if not uri and '' not in prefixes_seen: + # No default namespace declared => no prefix needed. + return tag, tag, uri + + for u, prefix in self._iter_namespaces(self._ns_stack): + if u == uri: + self._declared_ns_stack[-1].append((uri, prefix)) + return f'{prefix}:{tag}' if prefix else tag, tag, uri + + if not uri: + # As soon as a default namespace is defined, + # anything that has no namespace (and thus, no prefix) goes there. + return tag, tag, uri + + raise ValueError(f'Namespace "{uri}" of name "{tag}" is not declared in scope') + + def data(self, data): + if not self._ignored_depth: + self._data.append(data) + + cdef _flush(self): + cdef unicode data = ''.join(self._data) + del self._data[:] + if self._strip_text and not self._preserve_space[-1]: + data = data.strip() + if self._pending_start is not None: + (tag, attrs, new_namespaces), self._pending_start = self._pending_start, None + qname_text = data if ':' in data and _looks_like_prefix_name(data) else None + self._start(tag, attrs, new_namespaces, qname_text) + if qname_text is not None: + return + if data and self._root_seen: + self._write(_escape_cdata_c14n(data)) + + def start_ns(self, prefix, uri): + if self._ignored_depth: + return + # we may have to resolve qnames in text content + if self._data: + self._flush() + self._ns_stack[-1].append((uri, prefix)) + + def start(self, tag, attrs): + if self._exclude_tags is not None and ( + self._ignored_depth or tag in self._exclude_tags): + self._ignored_depth += 1 + return + if self._data: + self._flush() + + new_namespaces = [] + self._declared_ns_stack.append(new_namespaces) + + if self._qname_aware_tags is not None and tag in self._qname_aware_tags: + # Need to parse text first to see if it requires a prefix declaration. + self._pending_start = (tag, attrs, new_namespaces) + return + self._start(tag, attrs, new_namespaces) + + cdef _start(self, tag, attrs, new_namespaces, qname_text=None): + if self._exclude_attrs is not None and attrs: + attrs = {k: v for k, v in attrs.items() if k not in self._exclude_attrs} + + qnames = {tag, *attrs} + resolved_names = {} + + # Resolve prefixes in attribute and tag text. + if qname_text is not None: + qname = resolved_names[qname_text] = self._resolve_prefix_name(qname_text) + qnames.add(qname) + if self._find_qname_aware_attrs is not None and attrs: + qattrs = self._find_qname_aware_attrs(attrs) + if qattrs: + for attr_name in qattrs: + value = attrs[attr_name] + if _looks_like_prefix_name(value): + qname = resolved_names[value] = self._resolve_prefix_name(value) + qnames.add(qname) + else: + qattrs = None + else: + qattrs = None + + # Assign prefixes in lexicographical order of used URIs. + parsed_qnames = {n: self._qname(n) for n in sorted( + qnames, key=lambda n: n.split('}', 1))} + + # Write namespace declarations in prefix order ... + if new_namespaces: + attr_list = [ + ('xmlns:' + prefix if prefix else 'xmlns', uri) + for uri, prefix in new_namespaces + ] + attr_list.sort() + else: + # almost always empty + attr_list = [] + + # ... followed by attributes in URI+name order + if attrs: + for k, v in sorted(attrs.items()): + if qattrs is not None and k in qattrs and v in resolved_names: + v = parsed_qnames[resolved_names[v]][0] + attr_qname, attr_name, uri = parsed_qnames[k] + # No prefix for attributes in default ('') namespace. + attr_list.append((attr_qname if uri else attr_name, v)) + + # Honour xml:space attributes. + space_behaviour = attrs.get('{http://www.w3.org/XML/1998/namespace}space') + self._preserve_space.append( + space_behaviour == 'preserve' if space_behaviour + else self._preserve_space[-1]) + + # Write the tag. + write = self._write + write('<' + parsed_qnames[tag][0]) + if attr_list: + write(''.join([f' {k}="{_escape_attrib_c14n(v)}"' for k, v in attr_list])) + write('>') + + # Write the resolved qname text content. + if qname_text is not None: + write(_escape_cdata_c14n(parsed_qnames[resolved_names[qname_text]][0])) + + self._root_seen = True + self._ns_stack.append([]) + + def end(self, tag): + if self._ignored_depth: + self._ignored_depth -= 1 + return + if self._data: + self._flush() + self._write(f'') + self._preserve_space.pop() + self._root_done = len(self._preserve_space) == 1 + self._declared_ns_stack.pop() + self._ns_stack.pop() + + def comment(self, text): + if not self._with_comments: + return + if self._ignored_depth: + return + if self._root_done: + self._write('\n') + elif self._root_seen and self._data: + self._flush() + self._write(f'') + if not self._root_seen: + self._write('\n') + + def pi(self, target, data): + if self._ignored_depth: + return + if self._root_done: + self._write('\n') + elif self._root_seen and self._data: + self._flush() + self._write( + f'' if data else f'') + if not self._root_seen: + self._write('\n') + + def close(self): + return None + + +cdef _raise_serialization_error(text): + raise TypeError("cannot serialize %r (type %s)" % (text, type(text).__name__)) + + +cdef unicode _escape_cdata_c14n(stext): + # escape character data + cdef unicode text + cdef Py_UCS4 ch + cdef Py_ssize_t start = 0, pos = 0 + cdef list substrings = None + try: + text = unicode(stext) + except (TypeError, AttributeError): + return _raise_serialization_error(stext) + + for pos, ch in enumerate(text): + if ch == '&': + escape = '&' + elif ch == '<': + escape = '<' + elif ch == '>': + escape = '>' + elif ch == '\r': + escape = ' ' + else: + continue + + if substrings is None: + substrings = [] + if pos > start: + substrings.append(text[start:pos]) + substrings.append(escape) + start = pos + 1 + + if substrings is None: + return text + if pos >= start: + substrings.append(text[start:pos+1]) + return ''.join(substrings) + + +cdef unicode _escape_attrib_c14n(stext): + # escape attribute value + cdef unicode text + cdef Py_UCS4 ch + cdef Py_ssize_t start = 0, pos = 0 + cdef list substrings = None + try: + text = unicode(stext) + except (TypeError, AttributeError): + return _raise_serialization_error(stext) + + for pos, ch in enumerate(text): + if ch == '&': + escape = '&' + elif ch == '<': + escape = '<' + elif ch == '"': + escape = '"' + elif ch == '\t': + escape = ' ' + elif ch == '\n': + escape = ' ' + elif ch == '\r': + escape = ' ' + else: + continue + + if substrings is None: + substrings = [] + if pos > start: + substrings.append(text[start:pos]) + substrings.append(escape) + start = pos + 1 + + if substrings is None: + return text + if pos >= start: + substrings.append(text[start:pos+1]) + return ''.join(substrings) + + +# incremental serialisation + +cdef class xmlfile: + """xmlfile(self, output_file, encoding=None, compression=None, close=False, buffered=True) + + A simple mechanism for incremental XML serialisation. + + Usage example:: + + with xmlfile("somefile.xml", encoding='utf-8') as xf: + xf.write_declaration(standalone=True) + xf.write_doctype('') + + # generate an element (the root element) + with xf.element('root'): + # write a complete Element into the open root element + xf.write(etree.Element('test')) + + # generate and write more Elements, e.g. through iterparse + for element in generate_some_elements(): + # serialise generated elements into the XML file + xf.write(element) + + # or write multiple Elements or strings at once + xf.write(etree.Element('start'), "text", etree.Element('end')) + + If 'output_file' is a file(-like) object, passing ``close=True`` will + close it when exiting the context manager. By default, it is left + to the owner to do that. When a file path is used, lxml will take care + of opening and closing the file itself. Also, when a compression level + is set, lxml will deliberately close the file to make sure all data gets + compressed and written. + + Setting ``buffered=False`` will flush the output after each operation, + such as opening or closing an ``xf.element()`` block or calling + ``xf.write()``. Alternatively, calling ``xf.flush()`` can be used to + explicitly flush any pending output when buffering is enabled. + """ + cdef object output_file + cdef bytes encoding + cdef _IncrementalFileWriter writer + cdef _AsyncIncrementalFileWriter async_writer + cdef int compresslevel + cdef bint close + cdef bint buffered + cdef int method + + def __init__(self, output_file not None, encoding=None, compression=None, + close=False, buffered=True): + self.output_file = output_file + self.encoding = _utf8orNone(encoding) + self.compresslevel = compression or 0 + self.close = close + self.buffered = buffered + self.method = OUTPUT_METHOD_XML + + def __enter__(self): + assert self.output_file is not None + self.writer = _IncrementalFileWriter( + self.output_file, self.encoding, self.compresslevel, + self.close, self.buffered, self.method) + return self.writer + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.writer is not None: + old_writer, self.writer = self.writer, None + raise_on_error = exc_type is None + old_writer._close(raise_on_error) + if self.close: + self.output_file = None + + async def __aenter__(self): + assert self.output_file is not None + if isinstance(self.output_file, basestring): + raise TypeError("Cannot asynchronously write to a plain file") + if not hasattr(self.output_file, 'write'): + raise TypeError("Output file needs an async .write() method") + self.async_writer = _AsyncIncrementalFileWriter( + self.output_file, self.encoding, self.compresslevel, + self.close, self.buffered, self.method) + return self.async_writer + + async def __aexit__(self, exc_type, exc_val, exc_tb): + if self.async_writer is not None: + old_writer, self.async_writer = self.async_writer, None + raise_on_error = exc_type is None + await old_writer._close(raise_on_error) + if self.close: + self.output_file = None + + +cdef class htmlfile(xmlfile): + """htmlfile(self, output_file, encoding=None, compression=None, close=False, buffered=True) + + A simple mechanism for incremental HTML serialisation. Works the same as + xmlfile. + """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.method = OUTPUT_METHOD_HTML + + +cdef enum _IncrementalFileWriterStatus: + WRITER_STARTING = 0 + WRITER_DECL_WRITTEN = 1 + WRITER_DTD_WRITTEN = 2 + WRITER_IN_ELEMENT = 3 + WRITER_FINISHED = 4 + + +@cython.final +@cython.internal +cdef class _IncrementalFileWriter: + cdef tree.xmlOutputBuffer* _c_out + cdef bytes _encoding + cdef const_char* _c_encoding + cdef _FilelikeWriter _target + cdef list _element_stack + cdef int _status + cdef int _method + cdef bint _buffered + + def __cinit__(self, outfile, bytes encoding, int compresslevel, bint close, + bint buffered, int method): + self._status = WRITER_STARTING + self._element_stack = [] + if encoding is None: + # We always need a document encoding to make the attribute serialisation + # of libxml2 identical to ours. + encoding = b'ASCII' + self._encoding = encoding + self._c_encoding = _cstr(encoding) + self._buffered = buffered + self._target = _create_output_buffer( + outfile, self._c_encoding, compresslevel, &self._c_out, close) + self._method = method + + def __dealloc__(self): + if self._c_out is not NULL: + tree.xmlOutputBufferClose(self._c_out) + + def write_declaration(self, version=None, standalone=None, doctype=None): + """write_declaration(self, version=None, standalone=None, doctype=None) + + Write an XML declaration and (optionally) a doctype into the file. + """ + assert self._c_out is not NULL + cdef const_xmlChar* c_version + cdef int c_standalone + if self._method != OUTPUT_METHOD_XML: + raise LxmlSyntaxError("only XML documents have declarations") + if self._status >= WRITER_DECL_WRITTEN: + raise LxmlSyntaxError("XML declaration already written") + version = _utf8orNone(version) + c_version = _xcstr(version) if version is not None else NULL + doctype = _utf8orNone(doctype) + if standalone is None: + c_standalone = -1 + else: + c_standalone = 1 if standalone else 0 + _writeDeclarationToBuffer(self._c_out, c_version, self._c_encoding, c_standalone) + if doctype is not None: + _writeDoctype(self._c_out, _xcstr(doctype)) + self._status = WRITER_DTD_WRITTEN + else: + self._status = WRITER_DECL_WRITTEN + if not self._buffered: + tree.xmlOutputBufferFlush(self._c_out) + self._handle_error(self._c_out.error) + + def write_doctype(self, doctype): + """write_doctype(self, doctype) + + Writes the given doctype declaration verbatimly into the file. + """ + assert self._c_out is not NULL + if doctype is None: + return + if self._status >= WRITER_DTD_WRITTEN: + raise LxmlSyntaxError("DOCTYPE already written or cannot write it here") + doctype = _utf8(doctype) + _writeDoctype(self._c_out, _xcstr(doctype)) + self._status = WRITER_DTD_WRITTEN + if not self._buffered: + tree.xmlOutputBufferFlush(self._c_out) + self._handle_error(self._c_out.error) + + def method(self, method): + """method(self, method) + + Returns a context manager that overrides and restores the output method. + method is one of (None, 'xml', 'html') where None means 'xml'. + """ + assert self._c_out is not NULL + c_method = self._method if method is None else _findOutputMethod(method) + return _MethodChanger(self, c_method) + + def element(self, tag, attrib=None, nsmap=None, method=None, **_extra): + """element(self, tag, attrib=None, nsmap=None, method, **_extra) + + Returns a context manager that writes an opening and closing tag. + method is one of (None, 'xml', 'html') where None means 'xml'. + """ + assert self._c_out is not NULL + attributes = [] + if attrib is not None: + for name, value in _iter_attrib(attrib): + if name not in _extra: + ns, name = _getNsTag(name) + attributes.append((ns, name, _utf8(value))) + if _extra: + for name, value in _extra.iteritems(): + ns, name = _getNsTag(name) + attributes.append((ns, name, _utf8(value))) + reversed_nsmap = {} + if nsmap: + for prefix, ns in nsmap.items(): + if prefix is not None: + prefix = _utf8(prefix) + _prefixValidOrRaise(prefix) + reversed_nsmap[_utf8(ns)] = prefix + ns, name = _getNsTag(tag) + + c_method = self._method if method is None else _findOutputMethod(method) + + return _FileWriterElement(self, (ns, name, attributes, reversed_nsmap), c_method) + + cdef _write_qname(self, bytes name, bytes prefix): + if prefix: # empty bytes for no prefix (not None to allow sorting) + tree.xmlOutputBufferWrite(self._c_out, len(prefix), _cstr(prefix)) + tree.xmlOutputBufferWrite(self._c_out, 1, ':') + tree.xmlOutputBufferWrite(self._c_out, len(name), _cstr(name)) + + cdef _write_start_element(self, element_config): + if self._status > WRITER_IN_ELEMENT: + raise LxmlSyntaxError("cannot append trailing element to complete XML document") + ns, name, attributes, nsmap = element_config + flat_namespace_map, new_namespaces = self._collect_namespaces(nsmap) + prefix = self._find_prefix(ns, flat_namespace_map, new_namespaces) + tree.xmlOutputBufferWrite(self._c_out, 1, '<') + self._write_qname(name, prefix) + + self._write_attributes_and_namespaces( + attributes, flat_namespace_map, new_namespaces) + + tree.xmlOutputBufferWrite(self._c_out, 1, '>') + if not self._buffered: + tree.xmlOutputBufferFlush(self._c_out) + self._handle_error(self._c_out.error) + + self._element_stack.append((ns, name, prefix, flat_namespace_map)) + self._status = WRITER_IN_ELEMENT + + cdef _write_attributes_and_namespaces(self, list attributes, + dict flat_namespace_map, + list new_namespaces): + if attributes: + # _find_prefix() may append to new_namespaces => build them first + attributes = [ + (self._find_prefix(ns, flat_namespace_map, new_namespaces), name, value) + for ns, name, value in attributes ] + if new_namespaces: + new_namespaces.sort() + self._write_attributes_list(new_namespaces) + if attributes: + self._write_attributes_list(attributes) + + cdef _write_attributes_list(self, list attributes): + for prefix, name, value in attributes: + tree.xmlOutputBufferWrite(self._c_out, 1, ' ') + self._write_qname(name, prefix) + tree.xmlOutputBufferWrite(self._c_out, 2, '="') + _write_attr_string(self._c_out, _cstr(value)) + + tree.xmlOutputBufferWrite(self._c_out, 1, '"') + + cdef _write_end_element(self, element_config): + if self._status != WRITER_IN_ELEMENT: + raise LxmlSyntaxError("not in an element") + if not self._element_stack or self._element_stack[-1][:2] != element_config[:2]: + raise LxmlSyntaxError("inconsistent exit action in context manager") + + # If previous write operations failed, the context manager exit might still call us. + # That is ok, but we stop writing closing tags and handling errors in that case. + # For all non-I/O errors, we continue writing closing tags if we can. + ok_to_write = self._c_out.error == xmlerror.XML_ERR_OK + + name, prefix = self._element_stack.pop()[1:3] + if ok_to_write: + tree.xmlOutputBufferWrite(self._c_out, 2, '') + + if not self._element_stack: + self._status = WRITER_FINISHED + if ok_to_write: + if not self._buffered: + tree.xmlOutputBufferFlush(self._c_out) + self._handle_error(self._c_out.error) + + cdef _find_prefix(self, bytes href, dict flat_namespaces_map, list new_namespaces): + if href is None: + return None + if href in flat_namespaces_map: + return flat_namespaces_map[href] + # need to create a new prefix + prefixes = flat_namespaces_map.values() + i = 0 + while True: + prefix = _utf8('ns%d' % i) + if prefix not in prefixes: + new_namespaces.append((b'xmlns', prefix, href)) + flat_namespaces_map[href] = prefix + return prefix + i += 1 + + cdef _collect_namespaces(self, dict nsmap): + new_namespaces = [] + flat_namespaces_map = {} + for ns, prefix in nsmap.iteritems(): + flat_namespaces_map[ns] = prefix + if prefix is None: + # use empty bytes rather than None to allow sorting + new_namespaces.append((b'', b'xmlns', ns)) + else: + new_namespaces.append((b'xmlns', prefix, ns)) + # merge in flat namespace map of parent + if self._element_stack: + for ns, prefix in (self._element_stack[-1][-1]).iteritems(): + if flat_namespaces_map.get(ns) is None: + # unknown or empty prefix => prefer a 'real' prefix + flat_namespaces_map[ns] = prefix + return flat_namespaces_map, new_namespaces + + def write(self, *args, bint with_tail=True, bint pretty_print=False, method=None): + """write(self, *args, with_tail=True, pretty_print=False, method=None) + + Write subtrees or strings into the file. + + If method is not None, it should be one of ('html', 'xml', 'text') + to temporarily override the output method. + """ + assert self._c_out is not NULL + c_method = self._method if method is None else _findOutputMethod(method) + + for content in args: + if _isString(content): + if self._status != WRITER_IN_ELEMENT: + if self._status > WRITER_IN_ELEMENT or content.strip(): + raise LxmlSyntaxError("not in an element") + bstring = _utf8(content) + if not bstring: + continue + + ns, name, _, _ = self._element_stack[-1] + if (c_method == OUTPUT_METHOD_HTML and + ns in (None, b'http://www.w3.org/1999/xhtml') and + name in (b'script', b'style')): + tree.xmlOutputBufferWrite(self._c_out, len(bstring), _cstr(bstring)) + + else: + tree.xmlOutputBufferWriteEscape(self._c_out, _xcstr(bstring), NULL) + + elif isinstance(content, CDATA): + if self._status > WRITER_IN_ELEMENT: + raise LxmlSyntaxError("not in an element") + _write_cdata_string(self._c_out, (content)._utf8_data) + + elif iselement(content): + if self._status > WRITER_IN_ELEMENT: + raise LxmlSyntaxError("cannot append trailing element to complete XML document") + _writeNodeToBuffer(self._c_out, (<_Element>content)._c_node, + self._c_encoding, NULL, c_method, + False, False, pretty_print, with_tail, False) + if (<_Element>content)._c_node.type == tree.XML_ELEMENT_NODE: + if not self._element_stack: + self._status = WRITER_FINISHED + + elif content is not None: + raise TypeError( + f"got invalid input value of type {type(content)}, expected string, CDATA or Element") + + self._handle_error(self._c_out.error) + + if not self._buffered: + tree.xmlOutputBufferFlush(self._c_out) + self._handle_error(self._c_out.error) + + def flush(self): + """flush(self) + + Write any pending content of the current output buffer to the stream. + """ + assert self._c_out is not NULL + tree.xmlOutputBufferFlush(self._c_out) + self._handle_error(self._c_out.error) + + cdef _close(self, bint raise_on_error): + if raise_on_error: + if self._status < WRITER_IN_ELEMENT: + raise LxmlSyntaxError("no content written") + if self._element_stack: + raise LxmlSyntaxError("pending open tags on close") + error_result = self._c_out.error + if error_result == xmlerror.XML_ERR_OK: + error_result = tree.xmlOutputBufferClose(self._c_out) + if error_result != -1: + error_result = xmlerror.XML_ERR_OK + else: + tree.xmlOutputBufferClose(self._c_out) + self._status = WRITER_FINISHED + self._c_out = NULL + del self._element_stack[:] + if raise_on_error: + self._handle_error(error_result) + + cdef _handle_error(self, int error_result): + if error_result != xmlerror.XML_ERR_OK: + if self._target is not None: + self._target._exc_context._raise_if_stored() + _raiseSerialisationError(error_result) + + +@cython.final +@cython.internal +cdef class _AsyncDataWriter: + cdef list _data + def __cinit__(self): + self._data = [] + + cdef bytes collect(self): + data = b''.join(self._data) + del self._data[:] + return data + + def write(self, data): + self._data.append(data) + + def close(self): + pass + + +@cython.final +@cython.internal +cdef class _AsyncIncrementalFileWriter: + cdef _IncrementalFileWriter _writer + cdef _AsyncDataWriter _buffer + cdef object _async_outfile + cdef int _flush_after_writes + cdef bint _should_close + cdef bint _buffered + + def __cinit__(self, async_outfile, bytes encoding, int compresslevel, bint close, + bint buffered, int method): + self._flush_after_writes = 20 + self._async_outfile = async_outfile + self._should_close = close + self._buffered = buffered + self._buffer = _AsyncDataWriter() + self._writer = _IncrementalFileWriter( + self._buffer, encoding, compresslevel, close=True, buffered=False, method=method) + + cdef bytes _flush(self): + if not self._buffered or len(self._buffer._data) > self._flush_after_writes: + return self._buffer.collect() + return None + + async def flush(self): + self._writer.flush() + data = self._buffer.collect() + if data: + await self._async_outfile.write(data) + + async def write_declaration(self, version=None, standalone=None, doctype=None): + self._writer.write_declaration(version, standalone, doctype) + data = self._flush() + if data: + await self._async_outfile.write(data) + + async def write_doctype(self, doctype): + self._writer.write_doctype(doctype) + data = self._flush() + if data: + await self._async_outfile.write(data) + + async def write(self, *args, with_tail=True, pretty_print=False, method=None): + self._writer.write(*args, with_tail=with_tail, pretty_print=pretty_print, method=method) + data = self._flush() + if data: + await self._async_outfile.write(data) + + def method(self, method): + return self._writer.method(method) + + def element(self, tag, attrib=None, nsmap=None, method=None, **_extra): + element_writer = self._writer.element(tag, attrib, nsmap, method, **_extra) + return _AsyncFileWriterElement(element_writer, self) + + async def _close(self, bint raise_on_error): + self._writer._close(raise_on_error) + data = self._buffer.collect() + if data: + await self._async_outfile.write(data) + if self._should_close: + await self._async_outfile.close() + + +@cython.final +@cython.internal +cdef class _AsyncFileWriterElement: + cdef _FileWriterElement _element_writer + cdef _AsyncIncrementalFileWriter _writer + + def __cinit__(self, _FileWriterElement element_writer not None, + _AsyncIncrementalFileWriter writer not None): + self._element_writer = element_writer + self._writer = writer + + async def __aenter__(self): + self._element_writer.__enter__() + data = self._writer._flush() + if data: + await self._writer._async_outfile.write(data) + + async def __aexit__(self, *args): + self._element_writer.__exit__(*args) + data = self._writer._flush() + if data: + await self._writer._async_outfile.write(data) + + +@cython.final +@cython.internal +@cython.freelist(8) +cdef class _FileWriterElement: + cdef _IncrementalFileWriter _writer + cdef object _element + cdef int _new_method + cdef int _old_method + + def __cinit__(self, _IncrementalFileWriter writer not None, element_config, int method): + self._writer = writer + self._element = element_config + self._new_method = method + self._old_method = writer._method + + def __enter__(self): + self._writer._method = self._new_method + self._writer._write_start_element(self._element) + + def __exit__(self, exc_type, exc_val, exc_tb): + self._writer._write_end_element(self._element) + self._writer._method = self._old_method + + +@cython.final +@cython.internal +@cython.freelist(8) +cdef class _MethodChanger: + cdef _IncrementalFileWriter _writer + cdef int _new_method + cdef int _old_method + cdef bint _entered + cdef bint _exited + + def __cinit__(self, _IncrementalFileWriter writer not None, int method): + self._writer = writer + self._new_method = method + self._old_method = writer._method + self._entered = False + self._exited = False + + def __enter__(self): + if self._entered: + raise LxmlSyntaxError("Inconsistent enter action in context manager") + self._writer._method = self._new_method + self._entered = True + + def __exit__(self, exc_type, exc_val, exc_tb): + if self._exited: + raise LxmlSyntaxError("Inconsistent exit action in context manager") + if self._writer._method != self._new_method: + raise LxmlSyntaxError("Method changed outside of context manager") + self._writer._method = self._old_method + self._exited = True + + async def __aenter__(self): + # for your async convenience + return self.__enter__() + + async def __aexit__(self, *args): + # for your async convenience + return self.__exit__(*args) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/xinclude.pxi b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/xinclude.pxi new file mode 100644 index 0000000000000000000000000000000000000000..5c9ac45096efb2250e268dd2eed9ade07c2ca998 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/xinclude.pxi @@ -0,0 +1,67 @@ +# XInclude processing + +from lxml.includes cimport xinclude + + +cdef class XIncludeError(LxmlError): + """Error during XInclude processing. + """ + + +cdef class XInclude: + """XInclude(self) + XInclude processor. + + Create an instance and call it on an Element to run XInclude + processing. + """ + cdef _ErrorLog _error_log + def __init__(self): + self._error_log = _ErrorLog() + + @property + def error_log(self): + assert self._error_log is not None, "XInclude instance not initialised" + return self._error_log.copy() + + def __call__(self, _Element node not None): + "__call__(self, node)" + # We cannot pass the XML_PARSE_NOXINCNODE option as this would free + # the XInclude nodes - there may still be Python references to them! + # Therefore, we allow XInclude nodes to be converted to + # XML_XINCLUDE_START nodes. XML_XINCLUDE_END nodes are added as + # siblings. Tree traversal will simply ignore them as they are not + # typed as elements. The included fragment is added between the two, + # i.e. as a sibling, which does not conflict with traversal. + cdef int result + _assertValidNode(node) + assert self._error_log is not None, "XInclude processor not initialised" + if node._doc._parser is not None: + parse_options = node._doc._parser._parse_options + context = node._doc._parser._getParserContext() + c_context = context + else: + parse_options = 0 + context = None + c_context = NULL + + self._error_log.connect() + if tree.LIBXML_VERSION < 20704 or not c_context: + __GLOBAL_PARSER_CONTEXT.pushImpliedContext(context) + with nogil: + orig_loader = _register_document_loader() + if c_context: + result = xinclude.xmlXIncludeProcessTreeFlagsData( + node._c_node, parse_options, c_context) + else: + result = xinclude.xmlXIncludeProcessTree(node._c_node) + _reset_document_loader(orig_loader) + if tree.LIBXML_VERSION < 20704 or not c_context: + __GLOBAL_PARSER_CONTEXT.popImpliedContext() + self._error_log.disconnect() + + if result == -1: + raise XIncludeError( + self._error_log._buildExceptionMessage( + "XInclude processing failed"), + self._error_log) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/xmlerror.pxi b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/xmlerror.pxi new file mode 100644 index 0000000000000000000000000000000000000000..3be24d21230b3c2c9a7d34bfc31ef969147877d6 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/xmlerror.pxi @@ -0,0 +1,1662 @@ +# DEBUG and error logging + +from lxml.includes cimport xmlerror +from lxml cimport cvarargs + +DEF GLOBAL_ERROR_LOG = "_GlobalErrorLog" +DEF XSLT_ERROR_LOG = "_XSLTErrorLog" + +# module level API functions + +def clear_error_log(): + """clear_error_log() + + Clear the global error log. Note that this log is already bound to a + fixed size. + + Note: since lxml 2.2, the global error log is local to a thread + and this function will only clear the global error log of the + current thread. + """ + _getThreadErrorLog(GLOBAL_ERROR_LOG).clear() + + +# setup for global log: + +cdef void _initThreadLogging() noexcept: + # Disable generic error lines from libxml2. + _connectGenericErrorLog(None) + + # Divert XSLT error messages to the global XSLT error log instead of stderr. + xslt.xsltSetGenericErrorFunc(NULL, _receiveXSLTError) + + +# Logging classes + +@cython.final +@cython.freelist(16) +cdef class _LogEntry: + """A log message entry from an error log. + + Attributes: + + - message: the message text + - domain: the domain ID (see lxml.etree.ErrorDomains) + - type: the message type ID (see lxml.etree.ErrorTypes) + - level: the log level ID (see lxml.etree.ErrorLevels) + - line: the line at which the message originated (if applicable) + - column: the character column at which the message originated (if applicable) + - filename: the name of the file in which the message originated (if applicable) + - path: the location in which the error was found (if available) + """ + cdef readonly int domain + cdef readonly int type + cdef readonly int level + cdef readonly long line + cdef readonly int column + cdef basestring _message + cdef basestring _filename + cdef char* _c_message + cdef xmlChar* _c_filename + cdef xmlChar* _c_path + + def __dealloc__(self): + tree.xmlFree(self._c_message) + tree.xmlFree(self._c_filename) + tree.xmlFree(self._c_path) + + @cython.final + cdef int _setError(self, const xmlerror.xmlError* error) except -1: + self.domain = error.domain + self.type = error.code + self.level = error.level + self.line = error.line + self.column = error.int2 + self._c_message = NULL + self._c_filename = NULL + self._c_path = NULL + if (error.message is NULL or + error.message[0] == b'\0' or + error.message[0] == b'\n' and error.message[1] == b'\0'): + self._message = "unknown error" + else: + self._message = None + self._c_message = tree.xmlStrdup( + error.message) + if not self._c_message: + raise MemoryError() + if error.file is NULL: + self._filename = '' + else: + self._filename = None + self._c_filename = tree.xmlStrdup( error.file) + if not self._c_filename: + raise MemoryError() + if error.node is not NULL: + self._c_path = tree.xmlGetNodePath( error.node) + c_line = tree.xmlGetLineNo( error.node) + if c_line > limits.INT_MAX: + self.line = c_line + + @cython.final + cdef _setGeneric(self, int domain, int type, int level, long line, + message, filename): + self.domain = domain + self.type = type + self.level = level + self.line = line + self.column = 0 + self._message = message + self._filename = filename + self._c_path = NULL + + def __repr__(self): + return "%s:%d:%d:%s:%s:%s: %s" % ( + self.filename, self.line, self.column, self.level_name, + self.domain_name, self.type_name, self.message) + + @property + def domain_name(self): + """The name of the error domain. See lxml.etree.ErrorDomains + """ + return ErrorDomains._getName(self.domain, "unknown") + + @property + def type_name(self): + """The name of the error type. See lxml.etree.ErrorTypes + """ + if self.domain == ErrorDomains.RELAXNGV: + getName = RelaxNGErrorTypes._getName + else: + getName = ErrorTypes._getName + return getName(self.type, "unknown") + + @property + def level_name(self): + """The name of the error level. See lxml.etree.ErrorLevels + """ + return ErrorLevels._getName(self.level, "unknown") + + @property + def message(self): + """The log message string. + """ + cdef size_t size + if self._message is not None: + return self._message + if self._c_message is NULL: + return None + size = cstring_h.strlen(self._c_message) + if size > 0 and self._c_message[size-1] == b'\n': + size -= 1 # strip EOL + # cannot use funicode() here because the message may contain + # byte encoded file paths etc. + try: + self._message = self._c_message[:size].decode('utf8') + except UnicodeDecodeError: + try: + self._message = self._c_message[:size].decode( + 'ascii', 'backslashreplace') + except UnicodeDecodeError: + self._message = '' + if self._c_message: + # clean up early + tree.xmlFree(self._c_message) + self._c_message = NULL + return self._message + + @property + def filename(self): + """The file path where the report originated, if any. + """ + if self._filename is None: + if self._c_filename is not NULL: + self._filename = _decodeFilename(self._c_filename) + # clean up early + tree.xmlFree(self._c_filename) + self._c_filename = NULL + return self._filename + + @property + def path(self): + """The XPath for the node where the error was detected. + """ + return funicode(self._c_path) if self._c_path is not NULL else None + + +cdef class _BaseErrorLog: + cdef _LogEntry _first_error + cdef readonly object last_error + def __init__(self, first_error, last_error): + self._first_error = first_error + self.last_error = last_error + + cpdef copy(self): + return _BaseErrorLog(self._first_error, self.last_error) + + def __repr__(self): + return '' + + cpdef receive(self, _LogEntry entry): + pass + + @cython.final + cdef int _receive(self, const xmlerror.xmlError* error) except -1: + cdef bint is_error + cdef _LogEntry entry + cdef _BaseErrorLog global_log + entry = _LogEntry.__new__(_LogEntry) + entry._setError(error) + is_error = error.level == xmlerror.XML_ERR_ERROR or \ + error.level == xmlerror.XML_ERR_FATAL + global_log = _getThreadErrorLog(GLOBAL_ERROR_LOG) + if global_log is not self: + global_log.receive(entry) + if is_error: + global_log.last_error = entry + self.receive(entry) + if is_error: + self.last_error = entry + + @cython.final + cdef int _receiveGeneric(self, int domain, int type, int level, long line, + message, filename) except -1: + cdef bint is_error + cdef _LogEntry entry + cdef _BaseErrorLog global_log + entry = _LogEntry.__new__(_LogEntry) + entry._setGeneric(domain, type, level, line, message, filename) + is_error = level == xmlerror.XML_ERR_ERROR or \ + level == xmlerror.XML_ERR_FATAL + global_log = _getThreadErrorLog(GLOBAL_ERROR_LOG) + if global_log is not self: + global_log.receive(entry) + if is_error: + global_log.last_error = entry + self.receive(entry) + if is_error: + self.last_error = entry + + @cython.final + cdef _buildParseException(self, exctype, default_message): + code = xmlerror.XML_ERR_INTERNAL_ERROR + if self._first_error is None: + return exctype(default_message, code, 0, 0) + message = self._first_error.message + if message: + code = self._first_error.type + else: + message = default_message + line = self._first_error.line + column = self._first_error.column + filename = self._first_error.filename + if line > 0: + if column > 0: + message = f"{message}, line {line}, column {column}" + else: + message = f"{message}, line {line}" + return exctype(message, code, line, column, filename) + + @cython.final + cdef _buildExceptionMessage(self, default_message): + if self._first_error is None: + return default_message + if self._first_error.message: + message = self._first_error.message + elif default_message is None: + return None + else: + message = default_message + if self._first_error.line > 0: + if self._first_error.column > 0: + message = f"{message}, line {self._first_error.line}, column {self._first_error.column}" + else: + message = f"{message}, line {self._first_error.line}" + return message + +cdef class _ListErrorLog(_BaseErrorLog): + "Immutable base version of a list based error log." + cdef list _entries + cdef int _offset + def __init__(self, entries, first_error, last_error): + if entries: + if first_error is None: + first_error = entries[0] + if last_error is None: + last_error = entries[-1] + _BaseErrorLog.__init__(self, first_error, last_error) + self._entries = entries + + cpdef copy(self): + """Creates a shallow copy of this error log. Reuses the list of + entries. + """ + cdef _ListErrorLog log = _ListErrorLog( + self._entries, self._first_error, self.last_error) + log._offset = self._offset + return log + + def __iter__(self): + entries = self._entries + if self._offset: + entries = islice(entries, self._offset) + return iter(entries) + + def __repr__(self): + return '\n'.join([repr(entry) for entry in self]) + + def __getitem__(self, index): + if self._offset: + index += self._offset + return self._entries[index] + + def __len__(self): + return len(self._entries) - self._offset + + def __contains__(self, error_type): + cdef Py_ssize_t i + for i, entry in enumerate(self._entries): + if i < self._offset: + continue + if entry.type == error_type: + return True + return False + + def __bool__(self): + return len(self._entries) > self._offset + + def filter_domains(self, domains): + """Filter the errors by the given domains and return a new error log + containing the matches. + """ + cdef _LogEntry entry + if isinstance(domains, int): + domains = (domains,) + filtered = [entry for entry in self if entry.domain in domains] + return _ListErrorLog(filtered, None, None) + + def filter_types(self, types): + """filter_types(self, types) + + Filter the errors by the given types and return a new error + log containing the matches. + """ + cdef _LogEntry entry + if isinstance(types, int): + types = (types,) + filtered = [entry for entry in self if entry.type in types] + return _ListErrorLog(filtered, None, None) + + def filter_levels(self, levels): + """filter_levels(self, levels) + + Filter the errors by the given error levels and return a new + error log containing the matches. + """ + cdef _LogEntry entry + if isinstance(levels, int): + levels = (levels,) + filtered = [entry for entry in self if entry.level in levels] + return _ListErrorLog(filtered, None, None) + + def filter_from_level(self, level): + """filter_from_level(self, level) + + Return a log with all messages of the requested level of worse. + """ + cdef _LogEntry entry + filtered = [entry for entry in self if entry.level >= level] + return _ListErrorLog(filtered, None, None) + + def filter_from_fatals(self): + """filter_from_fatals(self) + + Convenience method to get all fatal error messages. + """ + return self.filter_from_level(ErrorLevels.FATAL) + + def filter_from_errors(self): + """filter_from_errors(self) + + Convenience method to get all error messages or worse. + """ + return self.filter_from_level(ErrorLevels.ERROR) + + def filter_from_warnings(self): + """filter_from_warnings(self) + + Convenience method to get all warnings or worse. + """ + return self.filter_from_level(ErrorLevels.WARNING) + + +@cython.final +@cython.internal +cdef class _ErrorLogContext: + """ + Error log context for the 'with' statement. + Stores a reference to the current callbacks to allow for + recursively stacked log contexts. + """ + cdef xmlerror.xmlStructuredErrorFunc old_error_func + cdef void* old_error_context + cdef xmlerror.xmlGenericErrorFunc old_xslt_error_func + cdef void* old_xslt_error_context + cdef _BaseErrorLog old_xslt_error_log + + cdef int push_error_log(self, _BaseErrorLog log) except -1: + self.old_error_func = xmlerror.xmlStructuredError + self.old_error_context = xmlerror.xmlStructuredErrorContext + xmlerror.xmlSetStructuredErrorFunc( + log, _receiveError) + + # xslt.xsltSetGenericErrorFunc() is not thread-local => keep error log in TLS + self.old_xslt_error_func = xslt.xsltGenericError + self.old_xslt_error_context = xslt.xsltGenericErrorContext + self.old_xslt_error_log = _getThreadErrorLog(XSLT_ERROR_LOG) + _setThreadErrorLog(XSLT_ERROR_LOG, log) + xslt.xsltSetGenericErrorFunc( + NULL, _receiveXSLTError) + return 0 + + cdef int pop_error_log(self) except -1: + xmlerror.xmlSetStructuredErrorFunc( + self.old_error_context, self.old_error_func) + xslt.xsltSetGenericErrorFunc( + self.old_xslt_error_context, self.old_xslt_error_func) + _setThreadErrorLog(XSLT_ERROR_LOG, self.old_xslt_error_log) + self.old_xslt_error_log= None + return 0 + + +cdef class _ErrorLog(_ListErrorLog): + cdef list _logContexts + def __cinit__(self): + self._logContexts = [] + + def __init__(self): + _ListErrorLog.__init__(self, [], None, None) + + @cython.final + cdef int __enter__(self) except -1: + return self.connect() + + def __exit__(self, *args): + # TODO: make this a cdef function when Cython supports it + self.disconnect() + + @cython.final + cdef int connect(self) except -1: + self._first_error = None + del self._entries[:] + + cdef _ErrorLogContext context = _ErrorLogContext.__new__(_ErrorLogContext) + context.push_error_log(self) + self._logContexts.append(context) + return 0 + + @cython.final + cdef int disconnect(self) except -1: + cdef _ErrorLogContext context = self._logContexts.pop() + context.pop_error_log() + return 0 + + cpdef clear(self): + self._first_error = None + self.last_error = None + self._offset = 0 + del self._entries[:] + + cpdef copy(self): + """Creates a shallow copy of this error log and the list of entries. + """ + return _ListErrorLog( + self._entries[self._offset:], + self._first_error, self.last_error) + + def __iter__(self): + return iter(self._entries[self._offset:]) + + cpdef receive(self, _LogEntry entry): + if self._first_error is None and entry.level >= xmlerror.XML_ERR_ERROR: + self._first_error = entry + self._entries.append(entry) + +cdef class _DomainErrorLog(_ErrorLog): + def __init__(self, domains): + _ErrorLog.__init__(self) + self._accepted_domains = tuple(domains) + + cpdef receive(self, _LogEntry entry): + if entry.domain in self._accepted_domains: + _ErrorLog.receive(self, entry) + +cdef class _RotatingErrorLog(_ErrorLog): + cdef int _max_len + def __init__(self, max_len): + _ErrorLog.__init__(self) + self._max_len = max_len + + cpdef receive(self, _LogEntry entry): + if self._first_error is None and entry.level >= xmlerror.XML_ERR_ERROR: + self._first_error = entry + self._entries.append(entry) + + if len(self._entries) > self._max_len: + self._offset += 1 + if self._offset > self._max_len // 3: + offset = self._offset + self._offset = 0 + del self._entries[:offset] + +cdef class PyErrorLog(_BaseErrorLog): + """PyErrorLog(self, logger_name=None, logger=None) + A global error log that connects to the Python stdlib logging package. + + The constructor accepts an optional logger name or a readily + instantiated logger instance. + + If you want to change the mapping between libxml2's ErrorLevels and Python + logging levels, you can modify the level_map dictionary from a subclass. + + The default mapping is:: + + ErrorLevels.WARNING = logging.WARNING + ErrorLevels.ERROR = logging.ERROR + ErrorLevels.FATAL = logging.CRITICAL + + You can also override the method ``receive()`` that takes a LogEntry + object and calls ``self.log(log_entry, format_string, arg1, arg2, ...)`` + with appropriate data. + """ + cdef readonly dict level_map + cdef object _map_level + cdef object _log + def __init__(self, logger_name=None, logger=None): + _BaseErrorLog.__init__(self, None, None) + import logging + self.level_map = { + ErrorLevels.WARNING : logging.WARNING, + ErrorLevels.ERROR : logging.ERROR, + ErrorLevels.FATAL : logging.CRITICAL + } + self._map_level = self.level_map.get + if logger is None: + if logger_name: + logger = logging.getLogger(logger_name) + else: + logger = logging.getLogger() + self._log = logger.log + + cpdef copy(self): + """Dummy method that returns an empty error log. + """ + return _ListErrorLog([], None, None) + + def log(self, log_entry, message, *args): + """log(self, log_entry, message, *args) + + Called by the .receive() method to log a _LogEntry instance to + the Python logging system. This handles the error level + mapping. + + In the default implementation, the ``message`` argument + receives a complete log line, and there are no further + ``args``. To change the message format, it is best to + override the .receive() method instead of this one. + """ + self._log( + self._map_level(log_entry.level, 0), + message, *args + ) + + cpdef receive(self, _LogEntry log_entry): + """receive(self, log_entry) + + Receive a _LogEntry instance from the logging system. Calls + the .log() method with appropriate parameters:: + + self.log(log_entry, repr(log_entry)) + + You can override this method to provide your own log output + format. + """ + self.log(log_entry, repr(log_entry)) + +# thread-local, global list log to collect error output messages from +# libxml2/libxslt + +cdef _BaseErrorLog __GLOBAL_ERROR_LOG = _RotatingErrorLog(__MAX_LOG_SIZE) + + +cdef _BaseErrorLog _getThreadErrorLog(name): + """Retrieve the current error log with name 'name' of this thread.""" + cdef python.PyObject* thread_dict + thread_dict = python.PyThreadState_GetDict() + if thread_dict is NULL: + return __GLOBAL_ERROR_LOG + try: + return (thread_dict)[name] + except KeyError: + log = (thread_dict)[name] = \ + _RotatingErrorLog(__MAX_LOG_SIZE) + return log + + +cdef _setThreadErrorLog(name, _BaseErrorLog log): + """Set the global error log of this thread.""" + cdef python.PyObject* thread_dict + thread_dict = python.PyThreadState_GetDict() + if thread_dict is NULL: + if name == GLOBAL_ERROR_LOG: + global __GLOBAL_ERROR_LOG + __GLOBAL_ERROR_LOG = log + else: + (thread_dict)[name] = log + + +cdef __copyGlobalErrorLog(): + "Helper function for properties in exceptions." + return _getThreadErrorLog(GLOBAL_ERROR_LOG).copy() + + +def use_global_python_log(PyErrorLog log not None): + """use_global_python_log(log) + + Replace the global error log by an etree.PyErrorLog that uses the + standard Python logging package. + + Note that this disables access to the global error log from exceptions. + Parsers, XSLT etc. will continue to provide their normal local error log. + + Note: prior to lxml 2.2, this changed the error log globally. + Since lxml 2.2, the global error log is local to a thread and this + function will only set the global error log of the current thread. + """ + _setThreadErrorLog(GLOBAL_ERROR_LOG, log) + + +# local log functions: forward error to logger object +cdef void _forwardError(void* c_log_handler, const xmlerror.xmlError* error) noexcept with gil: + cdef _BaseErrorLog log_handler + if c_log_handler is not NULL: + log_handler = <_BaseErrorLog>c_log_handler + elif error.domain == xmlerror.XML_FROM_XSLT: + log_handler = _getThreadErrorLog(XSLT_ERROR_LOG) + else: + log_handler = _getThreadErrorLog(GLOBAL_ERROR_LOG) + log_handler._receive(error) + + +cdef void _receiveError(void* c_log_handler, const xmlerror.xmlError* error) noexcept nogil: + # no Python objects here, may be called without thread context ! + if __DEBUG: + _forwardError(c_log_handler, error) + + +cdef void _receiveXSLTError(void* c_log_handler, char* msg, ...) noexcept nogil: + # no Python objects here, may be called without thread context ! + cdef cvarargs.va_list args + cvarargs.va_start(args, msg) + _receiveGenericError(c_log_handler, xmlerror.XML_FROM_XSLT, msg, args) + cvarargs.va_end(args) + +cdef void _receiveRelaxNGParseError(void* c_log_handler, char* msg, ...) noexcept nogil: + # no Python objects here, may be called without thread context ! + cdef cvarargs.va_list args + cvarargs.va_start(args, msg) + _receiveGenericError(c_log_handler, xmlerror.XML_FROM_RELAXNGP, msg, args) + cvarargs.va_end(args) + +cdef void _receiveRelaxNGValidationError(void* c_log_handler, char* msg, ...) noexcept nogil: + # no Python objects here, may be called without thread context ! + cdef cvarargs.va_list args + cvarargs.va_start(args, msg) + _receiveGenericError(c_log_handler, xmlerror.XML_FROM_RELAXNGV, msg, args) + cvarargs.va_end(args) + +# dummy function: no log output at all +cdef void _nullGenericErrorFunc(void* ctxt, char* msg, ...) noexcept nogil: + pass + + +cdef void _connectGenericErrorLog(log, int c_domain=-1) noexcept: + cdef xmlerror.xmlGenericErrorFunc error_func = NULL + c_log = log + if c_domain == xmlerror.XML_FROM_XSLT: + error_func = _receiveXSLTError + elif c_domain == xmlerror.XML_FROM_RELAXNGP: + error_func = _receiveRelaxNGParseError + elif c_domain == xmlerror.XML_FROM_RELAXNGV: + error_func = _receiveRelaxNGValidationError + + if log is None or error_func is NULL: + c_log = NULL + error_func = _nullGenericErrorFunc + xmlerror.xmlSetGenericErrorFunc(c_log, error_func) + + +cdef void _receiveGenericError(void* c_log_handler, int c_domain, + char* msg, cvarargs.va_list args) noexcept nogil: + # no Python objects here, may be called without thread context ! + cdef xmlerror.xmlError c_error + cdef char* c_text + cdef char* c_message + cdef char* c_element + cdef char* c_pos + cdef char* c_name_pos + cdef char* c_str + cdef int text_size, element_size, format_count, c_int + if not __DEBUG or msg is NULL: + return + if msg[0] in b'\n\0': + return + + c_text = c_element = c_error.file = c_error.node = NULL + c_error.line = 0 + + # parse "NAME %s" chunks from the format string + c_name_pos = c_pos = msg + format_count = 0 + while c_pos[0]: + if c_pos[0] == b'%': + c_pos += 1 + if c_pos[0] == b's': # "%s" + format_count += 1 + c_str = cvarargs.va_charptr(args) + if c_pos == msg + 1: + c_text = c_str # msg == "%s..." + elif c_name_pos[0] == b'e': + if cstring_h.strncmp(c_name_pos, 'element %s', 10) == 0: + c_element = c_str + elif c_name_pos[0] == b'f': + if cstring_h.strncmp(c_name_pos, 'file %s', 7) == 0: + if cstring_h.strncmp('string://__STRING__XSLT', + c_str, 23) == 0: + c_str = '' + c_error.file = c_str + elif c_pos[0] == b'd': # "%d" + format_count += 1 + c_int = cvarargs.va_int(args) + if cstring_h.strncmp(c_name_pos, 'line %d', 7) == 0: + c_error.line = c_int + elif c_pos[0] != b'%': # "%%" == "%" + format_count += 1 + break # unexpected format or end of string => abort + elif c_pos[0] == b' ': + if c_pos[1] != b'%': + c_name_pos = c_pos + 1 + c_pos += 1 + + c_message = NULL + if c_text is NULL: + if c_element is not NULL and format_count == 1: + # special case: a single occurrence of 'element %s' + text_size = cstring_h.strlen(msg) + element_size = cstring_h.strlen(c_element) + c_message = stdlib.malloc( + (text_size + element_size + 1) * sizeof(char)) + stdio.sprintf(c_message, msg, c_element) + c_error.message = c_message + else: + c_error.message = '' + elif c_element is NULL: + c_error.message = c_text + else: + text_size = cstring_h.strlen(c_text) + element_size = cstring_h.strlen(c_element) + c_message = stdlib.malloc( + (text_size + 12 + element_size + 1) * sizeof(char)) + if c_message is NULL: + c_error.message = c_text + else: + stdio.sprintf(c_message, "%s, element '%s'", c_text, c_element) + c_error.message = c_message + + c_error.domain = c_domain + c_error.code = xmlerror.XML_ERR_OK # what else? + c_error.level = xmlerror.XML_ERR_ERROR # what else? + c_error.int2 = 0 + + _forwardError(c_log_handler, &c_error) + + if c_message is not NULL: + stdlib.free(c_message) + +################################################################################ +## CONSTANTS FROM "xmlerror.h" (or rather libxml-xmlerror.html) +################################################################################ + +cdef __initErrorConstants(): + "Called at setup time to parse the constants and build the classes below." + global __ERROR_LEVELS, __ERROR_DOMAINS, __PARSER_ERROR_TYPES, __RELAXNG_ERROR_TYPES + const_defs = ((ErrorLevels, __ERROR_LEVELS), + (ErrorDomains, __ERROR_DOMAINS), + (ErrorTypes, __PARSER_ERROR_TYPES), + (RelaxNGErrorTypes, __RELAXNG_ERROR_TYPES)) + + for cls, constants in const_defs: + reverse_dict = {} + cls._names = reverse_dict + cls._getName = reverse_dict.get + for line in constants.splitlines(): + if not line: + continue + name, value = line.split('=') + value = int(value) + setattr(cls, name, value) + reverse_dict[value] = name + + # discard the global string references after use + __ERROR_LEVELS = __ERROR_DOMAINS = __PARSER_ERROR_TYPES = __RELAXNG_ERROR_TYPES = None + + +class ErrorLevels(object): + """Libxml2 error levels""" + +class ErrorDomains(object): + """Libxml2 error domains""" + +class ErrorTypes(object): + """Libxml2 error types""" + +class RelaxNGErrorTypes(object): + """Libxml2 RelaxNG error types""" + + +# --- BEGIN: GENERATED CONSTANTS --- + +# This section is generated by the script 'update-error-constants.py'. + +cdef object __ERROR_LEVELS = """\ +NONE=0 +WARNING=1 +ERROR=2 +FATAL=3 +""" + +cdef object __ERROR_DOMAINS = """\ +NONE=0 +PARSER=1 +TREE=2 +NAMESPACE=3 +DTD=4 +HTML=5 +MEMORY=6 +OUTPUT=7 +IO=8 +FTP=9 +HTTP=10 +XINCLUDE=11 +XPATH=12 +XPOINTER=13 +REGEXP=14 +DATATYPE=15 +SCHEMASP=16 +SCHEMASV=17 +RELAXNGP=18 +RELAXNGV=19 +CATALOG=20 +C14N=21 +XSLT=22 +VALID=23 +CHECK=24 +WRITER=25 +MODULE=26 +I18N=27 +SCHEMATRONV=28 +BUFFER=29 +URI=30 +""" + +cdef object __PARSER_ERROR_TYPES = """\ +ERR_OK=0 +ERR_INTERNAL_ERROR=1 +ERR_NO_MEMORY=2 +ERR_DOCUMENT_START=3 +ERR_DOCUMENT_EMPTY=4 +ERR_DOCUMENT_END=5 +ERR_INVALID_HEX_CHARREF=6 +ERR_INVALID_DEC_CHARREF=7 +ERR_INVALID_CHARREF=8 +ERR_INVALID_CHAR=9 +ERR_CHARREF_AT_EOF=10 +ERR_CHARREF_IN_PROLOG=11 +ERR_CHARREF_IN_EPILOG=12 +ERR_CHARREF_IN_DTD=13 +ERR_ENTITYREF_AT_EOF=14 +ERR_ENTITYREF_IN_PROLOG=15 +ERR_ENTITYREF_IN_EPILOG=16 +ERR_ENTITYREF_IN_DTD=17 +ERR_PEREF_AT_EOF=18 +ERR_PEREF_IN_PROLOG=19 +ERR_PEREF_IN_EPILOG=20 +ERR_PEREF_IN_INT_SUBSET=21 +ERR_ENTITYREF_NO_NAME=22 +ERR_ENTITYREF_SEMICOL_MISSING=23 +ERR_PEREF_NO_NAME=24 +ERR_PEREF_SEMICOL_MISSING=25 +ERR_UNDECLARED_ENTITY=26 +WAR_UNDECLARED_ENTITY=27 +ERR_UNPARSED_ENTITY=28 +ERR_ENTITY_IS_EXTERNAL=29 +ERR_ENTITY_IS_PARAMETER=30 +ERR_UNKNOWN_ENCODING=31 +ERR_UNSUPPORTED_ENCODING=32 +ERR_STRING_NOT_STARTED=33 +ERR_STRING_NOT_CLOSED=34 +ERR_NS_DECL_ERROR=35 +ERR_ENTITY_NOT_STARTED=36 +ERR_ENTITY_NOT_FINISHED=37 +ERR_LT_IN_ATTRIBUTE=38 +ERR_ATTRIBUTE_NOT_STARTED=39 +ERR_ATTRIBUTE_NOT_FINISHED=40 +ERR_ATTRIBUTE_WITHOUT_VALUE=41 +ERR_ATTRIBUTE_REDEFINED=42 +ERR_LITERAL_NOT_STARTED=43 +ERR_LITERAL_NOT_FINISHED=44 +ERR_COMMENT_NOT_FINISHED=45 +ERR_PI_NOT_STARTED=46 +ERR_PI_NOT_FINISHED=47 +ERR_NOTATION_NOT_STARTED=48 +ERR_NOTATION_NOT_FINISHED=49 +ERR_ATTLIST_NOT_STARTED=50 +ERR_ATTLIST_NOT_FINISHED=51 +ERR_MIXED_NOT_STARTED=52 +ERR_MIXED_NOT_FINISHED=53 +ERR_ELEMCONTENT_NOT_STARTED=54 +ERR_ELEMCONTENT_NOT_FINISHED=55 +ERR_XMLDECL_NOT_STARTED=56 +ERR_XMLDECL_NOT_FINISHED=57 +ERR_CONDSEC_NOT_STARTED=58 +ERR_CONDSEC_NOT_FINISHED=59 +ERR_EXT_SUBSET_NOT_FINISHED=60 +ERR_DOCTYPE_NOT_FINISHED=61 +ERR_MISPLACED_CDATA_END=62 +ERR_CDATA_NOT_FINISHED=63 +ERR_RESERVED_XML_NAME=64 +ERR_SPACE_REQUIRED=65 +ERR_SEPARATOR_REQUIRED=66 +ERR_NMTOKEN_REQUIRED=67 +ERR_NAME_REQUIRED=68 +ERR_PCDATA_REQUIRED=69 +ERR_URI_REQUIRED=70 +ERR_PUBID_REQUIRED=71 +ERR_LT_REQUIRED=72 +ERR_GT_REQUIRED=73 +ERR_LTSLASH_REQUIRED=74 +ERR_EQUAL_REQUIRED=75 +ERR_TAG_NAME_MISMATCH=76 +ERR_TAG_NOT_FINISHED=77 +ERR_STANDALONE_VALUE=78 +ERR_ENCODING_NAME=79 +ERR_HYPHEN_IN_COMMENT=80 +ERR_INVALID_ENCODING=81 +ERR_EXT_ENTITY_STANDALONE=82 +ERR_CONDSEC_INVALID=83 +ERR_VALUE_REQUIRED=84 +ERR_NOT_WELL_BALANCED=85 +ERR_EXTRA_CONTENT=86 +ERR_ENTITY_CHAR_ERROR=87 +ERR_ENTITY_PE_INTERNAL=88 +ERR_ENTITY_LOOP=89 +ERR_ENTITY_BOUNDARY=90 +ERR_INVALID_URI=91 +ERR_URI_FRAGMENT=92 +WAR_CATALOG_PI=93 +ERR_NO_DTD=94 +ERR_CONDSEC_INVALID_KEYWORD=95 +ERR_VERSION_MISSING=96 +WAR_UNKNOWN_VERSION=97 +WAR_LANG_VALUE=98 +WAR_NS_URI=99 +WAR_NS_URI_RELATIVE=100 +ERR_MISSING_ENCODING=101 +WAR_SPACE_VALUE=102 +ERR_NOT_STANDALONE=103 +ERR_ENTITY_PROCESSING=104 +ERR_NOTATION_PROCESSING=105 +WAR_NS_COLUMN=106 +WAR_ENTITY_REDEFINED=107 +ERR_UNKNOWN_VERSION=108 +ERR_VERSION_MISMATCH=109 +ERR_NAME_TOO_LONG=110 +ERR_USER_STOP=111 +ERR_COMMENT_ABRUPTLY_ENDED=112 +WAR_ENCODING_MISMATCH=113 +ERR_RESOURCE_LIMIT=114 +ERR_ARGUMENT=115 +ERR_SYSTEM=116 +ERR_REDECL_PREDEF_ENTITY=117 +ERR_INT_SUBSET_NOT_FINISHED=118 +NS_ERR_XML_NAMESPACE=200 +NS_ERR_UNDEFINED_NAMESPACE=201 +NS_ERR_QNAME=202 +NS_ERR_ATTRIBUTE_REDEFINED=203 +NS_ERR_EMPTY=204 +NS_ERR_COLON=205 +DTD_ATTRIBUTE_DEFAULT=500 +DTD_ATTRIBUTE_REDEFINED=501 +DTD_ATTRIBUTE_VALUE=502 +DTD_CONTENT_ERROR=503 +DTD_CONTENT_MODEL=504 +DTD_CONTENT_NOT_DETERMINIST=505 +DTD_DIFFERENT_PREFIX=506 +DTD_ELEM_DEFAULT_NAMESPACE=507 +DTD_ELEM_NAMESPACE=508 +DTD_ELEM_REDEFINED=509 +DTD_EMPTY_NOTATION=510 +DTD_ENTITY_TYPE=511 +DTD_ID_FIXED=512 +DTD_ID_REDEFINED=513 +DTD_ID_SUBSET=514 +DTD_INVALID_CHILD=515 +DTD_INVALID_DEFAULT=516 +DTD_LOAD_ERROR=517 +DTD_MISSING_ATTRIBUTE=518 +DTD_MIXED_CORRUPT=519 +DTD_MULTIPLE_ID=520 +DTD_NO_DOC=521 +DTD_NO_DTD=522 +DTD_NO_ELEM_NAME=523 +DTD_NO_PREFIX=524 +DTD_NO_ROOT=525 +DTD_NOTATION_REDEFINED=526 +DTD_NOTATION_VALUE=527 +DTD_NOT_EMPTY=528 +DTD_NOT_PCDATA=529 +DTD_NOT_STANDALONE=530 +DTD_ROOT_NAME=531 +DTD_STANDALONE_WHITE_SPACE=532 +DTD_UNKNOWN_ATTRIBUTE=533 +DTD_UNKNOWN_ELEM=534 +DTD_UNKNOWN_ENTITY=535 +DTD_UNKNOWN_ID=536 +DTD_UNKNOWN_NOTATION=537 +DTD_STANDALONE_DEFAULTED=538 +DTD_XMLID_VALUE=539 +DTD_XMLID_TYPE=540 +DTD_DUP_TOKEN=541 +HTML_STRUCURE_ERROR=800 +HTML_UNKNOWN_TAG=801 +HTML_INCORRECTLY_OPENED_COMMENT=802 +RNGP_ANYNAME_ATTR_ANCESTOR=1000 +RNGP_ATTR_CONFLICT=1001 +RNGP_ATTRIBUTE_CHILDREN=1002 +RNGP_ATTRIBUTE_CONTENT=1003 +RNGP_ATTRIBUTE_EMPTY=1004 +RNGP_ATTRIBUTE_NOOP=1005 +RNGP_CHOICE_CONTENT=1006 +RNGP_CHOICE_EMPTY=1007 +RNGP_CREATE_FAILURE=1008 +RNGP_DATA_CONTENT=1009 +RNGP_DEF_CHOICE_AND_INTERLEAVE=1010 +RNGP_DEFINE_CREATE_FAILED=1011 +RNGP_DEFINE_EMPTY=1012 +RNGP_DEFINE_MISSING=1013 +RNGP_DEFINE_NAME_MISSING=1014 +RNGP_ELEM_CONTENT_EMPTY=1015 +RNGP_ELEM_CONTENT_ERROR=1016 +RNGP_ELEMENT_EMPTY=1017 +RNGP_ELEMENT_CONTENT=1018 +RNGP_ELEMENT_NAME=1019 +RNGP_ELEMENT_NO_CONTENT=1020 +RNGP_ELEM_TEXT_CONFLICT=1021 +RNGP_EMPTY=1022 +RNGP_EMPTY_CONSTRUCT=1023 +RNGP_EMPTY_CONTENT=1024 +RNGP_EMPTY_NOT_EMPTY=1025 +RNGP_ERROR_TYPE_LIB=1026 +RNGP_EXCEPT_EMPTY=1027 +RNGP_EXCEPT_MISSING=1028 +RNGP_EXCEPT_MULTIPLE=1029 +RNGP_EXCEPT_NO_CONTENT=1030 +RNGP_EXTERNALREF_EMTPY=1031 +RNGP_EXTERNAL_REF_FAILURE=1032 +RNGP_EXTERNALREF_RECURSE=1033 +RNGP_FORBIDDEN_ATTRIBUTE=1034 +RNGP_FOREIGN_ELEMENT=1035 +RNGP_GRAMMAR_CONTENT=1036 +RNGP_GRAMMAR_EMPTY=1037 +RNGP_GRAMMAR_MISSING=1038 +RNGP_GRAMMAR_NO_START=1039 +RNGP_GROUP_ATTR_CONFLICT=1040 +RNGP_HREF_ERROR=1041 +RNGP_INCLUDE_EMPTY=1042 +RNGP_INCLUDE_FAILURE=1043 +RNGP_INCLUDE_RECURSE=1044 +RNGP_INTERLEAVE_ADD=1045 +RNGP_INTERLEAVE_CREATE_FAILED=1046 +RNGP_INTERLEAVE_EMPTY=1047 +RNGP_INTERLEAVE_NO_CONTENT=1048 +RNGP_INVALID_DEFINE_NAME=1049 +RNGP_INVALID_URI=1050 +RNGP_INVALID_VALUE=1051 +RNGP_MISSING_HREF=1052 +RNGP_NAME_MISSING=1053 +RNGP_NEED_COMBINE=1054 +RNGP_NOTALLOWED_NOT_EMPTY=1055 +RNGP_NSNAME_ATTR_ANCESTOR=1056 +RNGP_NSNAME_NO_NS=1057 +RNGP_PARAM_FORBIDDEN=1058 +RNGP_PARAM_NAME_MISSING=1059 +RNGP_PARENTREF_CREATE_FAILED=1060 +RNGP_PARENTREF_NAME_INVALID=1061 +RNGP_PARENTREF_NO_NAME=1062 +RNGP_PARENTREF_NO_PARENT=1063 +RNGP_PARENTREF_NOT_EMPTY=1064 +RNGP_PARSE_ERROR=1065 +RNGP_PAT_ANYNAME_EXCEPT_ANYNAME=1066 +RNGP_PAT_ATTR_ATTR=1067 +RNGP_PAT_ATTR_ELEM=1068 +RNGP_PAT_DATA_EXCEPT_ATTR=1069 +RNGP_PAT_DATA_EXCEPT_ELEM=1070 +RNGP_PAT_DATA_EXCEPT_EMPTY=1071 +RNGP_PAT_DATA_EXCEPT_GROUP=1072 +RNGP_PAT_DATA_EXCEPT_INTERLEAVE=1073 +RNGP_PAT_DATA_EXCEPT_LIST=1074 +RNGP_PAT_DATA_EXCEPT_ONEMORE=1075 +RNGP_PAT_DATA_EXCEPT_REF=1076 +RNGP_PAT_DATA_EXCEPT_TEXT=1077 +RNGP_PAT_LIST_ATTR=1078 +RNGP_PAT_LIST_ELEM=1079 +RNGP_PAT_LIST_INTERLEAVE=1080 +RNGP_PAT_LIST_LIST=1081 +RNGP_PAT_LIST_REF=1082 +RNGP_PAT_LIST_TEXT=1083 +RNGP_PAT_NSNAME_EXCEPT_ANYNAME=1084 +RNGP_PAT_NSNAME_EXCEPT_NSNAME=1085 +RNGP_PAT_ONEMORE_GROUP_ATTR=1086 +RNGP_PAT_ONEMORE_INTERLEAVE_ATTR=1087 +RNGP_PAT_START_ATTR=1088 +RNGP_PAT_START_DATA=1089 +RNGP_PAT_START_EMPTY=1090 +RNGP_PAT_START_GROUP=1091 +RNGP_PAT_START_INTERLEAVE=1092 +RNGP_PAT_START_LIST=1093 +RNGP_PAT_START_ONEMORE=1094 +RNGP_PAT_START_TEXT=1095 +RNGP_PAT_START_VALUE=1096 +RNGP_PREFIX_UNDEFINED=1097 +RNGP_REF_CREATE_FAILED=1098 +RNGP_REF_CYCLE=1099 +RNGP_REF_NAME_INVALID=1100 +RNGP_REF_NO_DEF=1101 +RNGP_REF_NO_NAME=1102 +RNGP_REF_NOT_EMPTY=1103 +RNGP_START_CHOICE_AND_INTERLEAVE=1104 +RNGP_START_CONTENT=1105 +RNGP_START_EMPTY=1106 +RNGP_START_MISSING=1107 +RNGP_TEXT_EXPECTED=1108 +RNGP_TEXT_HAS_CHILD=1109 +RNGP_TYPE_MISSING=1110 +RNGP_TYPE_NOT_FOUND=1111 +RNGP_TYPE_VALUE=1112 +RNGP_UNKNOWN_ATTRIBUTE=1113 +RNGP_UNKNOWN_COMBINE=1114 +RNGP_UNKNOWN_CONSTRUCT=1115 +RNGP_UNKNOWN_TYPE_LIB=1116 +RNGP_URI_FRAGMENT=1117 +RNGP_URI_NOT_ABSOLUTE=1118 +RNGP_VALUE_EMPTY=1119 +RNGP_VALUE_NO_CONTENT=1120 +RNGP_XMLNS_NAME=1121 +RNGP_XML_NS=1122 +XPATH_EXPRESSION_OK=1200 +XPATH_NUMBER_ERROR=1201 +XPATH_UNFINISHED_LITERAL_ERROR=1202 +XPATH_START_LITERAL_ERROR=1203 +XPATH_VARIABLE_REF_ERROR=1204 +XPATH_UNDEF_VARIABLE_ERROR=1205 +XPATH_INVALID_PREDICATE_ERROR=1206 +XPATH_EXPR_ERROR=1207 +XPATH_UNCLOSED_ERROR=1208 +XPATH_UNKNOWN_FUNC_ERROR=1209 +XPATH_INVALID_OPERAND=1210 +XPATH_INVALID_TYPE=1211 +XPATH_INVALID_ARITY=1212 +XPATH_INVALID_CTXT_SIZE=1213 +XPATH_INVALID_CTXT_POSITION=1214 +XPATH_MEMORY_ERROR=1215 +XPTR_SYNTAX_ERROR=1216 +XPTR_RESOURCE_ERROR=1217 +XPTR_SUB_RESOURCE_ERROR=1218 +XPATH_UNDEF_PREFIX_ERROR=1219 +XPATH_ENCODING_ERROR=1220 +XPATH_INVALID_CHAR_ERROR=1221 +TREE_INVALID_HEX=1300 +TREE_INVALID_DEC=1301 +TREE_UNTERMINATED_ENTITY=1302 +TREE_NOT_UTF8=1303 +SAVE_NOT_UTF8=1400 +SAVE_CHAR_INVALID=1401 +SAVE_NO_DOCTYPE=1402 +SAVE_UNKNOWN_ENCODING=1403 +REGEXP_COMPILE_ERROR=1450 +IO_UNKNOWN=1500 +IO_EACCES=1501 +IO_EAGAIN=1502 +IO_EBADF=1503 +IO_EBADMSG=1504 +IO_EBUSY=1505 +IO_ECANCELED=1506 +IO_ECHILD=1507 +IO_EDEADLK=1508 +IO_EDOM=1509 +IO_EEXIST=1510 +IO_EFAULT=1511 +IO_EFBIG=1512 +IO_EINPROGRESS=1513 +IO_EINTR=1514 +IO_EINVAL=1515 +IO_EIO=1516 +IO_EISDIR=1517 +IO_EMFILE=1518 +IO_EMLINK=1519 +IO_EMSGSIZE=1520 +IO_ENAMETOOLONG=1521 +IO_ENFILE=1522 +IO_ENODEV=1523 +IO_ENOENT=1524 +IO_ENOEXEC=1525 +IO_ENOLCK=1526 +IO_ENOMEM=1527 +IO_ENOSPC=1528 +IO_ENOSYS=1529 +IO_ENOTDIR=1530 +IO_ENOTEMPTY=1531 +IO_ENOTSUP=1532 +IO_ENOTTY=1533 +IO_ENXIO=1534 +IO_EPERM=1535 +IO_EPIPE=1536 +IO_ERANGE=1537 +IO_EROFS=1538 +IO_ESPIPE=1539 +IO_ESRCH=1540 +IO_ETIMEDOUT=1541 +IO_EXDEV=1542 +IO_NETWORK_ATTEMPT=1543 +IO_ENCODER=1544 +IO_FLUSH=1545 +IO_WRITE=1546 +IO_NO_INPUT=1547 +IO_BUFFER_FULL=1548 +IO_LOAD_ERROR=1549 +IO_ENOTSOCK=1550 +IO_EISCONN=1551 +IO_ECONNREFUSED=1552 +IO_ENETUNREACH=1553 +IO_EADDRINUSE=1554 +IO_EALREADY=1555 +IO_EAFNOSUPPORT=1556 +IO_UNSUPPORTED_PROTOCOL=1557 +XINCLUDE_RECURSION=1600 +XINCLUDE_PARSE_VALUE=1601 +XINCLUDE_ENTITY_DEF_MISMATCH=1602 +XINCLUDE_NO_HREF=1603 +XINCLUDE_NO_FALLBACK=1604 +XINCLUDE_HREF_URI=1605 +XINCLUDE_TEXT_FRAGMENT=1606 +XINCLUDE_TEXT_DOCUMENT=1607 +XINCLUDE_INVALID_CHAR=1608 +XINCLUDE_BUILD_FAILED=1609 +XINCLUDE_UNKNOWN_ENCODING=1610 +XINCLUDE_MULTIPLE_ROOT=1611 +XINCLUDE_XPTR_FAILED=1612 +XINCLUDE_XPTR_RESULT=1613 +XINCLUDE_INCLUDE_IN_INCLUDE=1614 +XINCLUDE_FALLBACKS_IN_INCLUDE=1615 +XINCLUDE_FALLBACK_NOT_IN_INCLUDE=1616 +XINCLUDE_DEPRECATED_NS=1617 +XINCLUDE_FRAGMENT_ID=1618 +CATALOG_MISSING_ATTR=1650 +CATALOG_ENTRY_BROKEN=1651 +CATALOG_PREFER_VALUE=1652 +CATALOG_NOT_CATALOG=1653 +CATALOG_RECURSION=1654 +SCHEMAP_PREFIX_UNDEFINED=1700 +SCHEMAP_ATTRFORMDEFAULT_VALUE=1701 +SCHEMAP_ATTRGRP_NONAME_NOREF=1702 +SCHEMAP_ATTR_NONAME_NOREF=1703 +SCHEMAP_COMPLEXTYPE_NONAME_NOREF=1704 +SCHEMAP_ELEMFORMDEFAULT_VALUE=1705 +SCHEMAP_ELEM_NONAME_NOREF=1706 +SCHEMAP_EXTENSION_NO_BASE=1707 +SCHEMAP_FACET_NO_VALUE=1708 +SCHEMAP_FAILED_BUILD_IMPORT=1709 +SCHEMAP_GROUP_NONAME_NOREF=1710 +SCHEMAP_IMPORT_NAMESPACE_NOT_URI=1711 +SCHEMAP_IMPORT_REDEFINE_NSNAME=1712 +SCHEMAP_IMPORT_SCHEMA_NOT_URI=1713 +SCHEMAP_INVALID_BOOLEAN=1714 +SCHEMAP_INVALID_ENUM=1715 +SCHEMAP_INVALID_FACET=1716 +SCHEMAP_INVALID_FACET_VALUE=1717 +SCHEMAP_INVALID_MAXOCCURS=1718 +SCHEMAP_INVALID_MINOCCURS=1719 +SCHEMAP_INVALID_REF_AND_SUBTYPE=1720 +SCHEMAP_INVALID_WHITE_SPACE=1721 +SCHEMAP_NOATTR_NOREF=1722 +SCHEMAP_NOTATION_NO_NAME=1723 +SCHEMAP_NOTYPE_NOREF=1724 +SCHEMAP_REF_AND_SUBTYPE=1725 +SCHEMAP_RESTRICTION_NONAME_NOREF=1726 +SCHEMAP_SIMPLETYPE_NONAME=1727 +SCHEMAP_TYPE_AND_SUBTYPE=1728 +SCHEMAP_UNKNOWN_ALL_CHILD=1729 +SCHEMAP_UNKNOWN_ANYATTRIBUTE_CHILD=1730 +SCHEMAP_UNKNOWN_ATTR_CHILD=1731 +SCHEMAP_UNKNOWN_ATTRGRP_CHILD=1732 +SCHEMAP_UNKNOWN_ATTRIBUTE_GROUP=1733 +SCHEMAP_UNKNOWN_BASE_TYPE=1734 +SCHEMAP_UNKNOWN_CHOICE_CHILD=1735 +SCHEMAP_UNKNOWN_COMPLEXCONTENT_CHILD=1736 +SCHEMAP_UNKNOWN_COMPLEXTYPE_CHILD=1737 +SCHEMAP_UNKNOWN_ELEM_CHILD=1738 +SCHEMAP_UNKNOWN_EXTENSION_CHILD=1739 +SCHEMAP_UNKNOWN_FACET_CHILD=1740 +SCHEMAP_UNKNOWN_FACET_TYPE=1741 +SCHEMAP_UNKNOWN_GROUP_CHILD=1742 +SCHEMAP_UNKNOWN_IMPORT_CHILD=1743 +SCHEMAP_UNKNOWN_LIST_CHILD=1744 +SCHEMAP_UNKNOWN_NOTATION_CHILD=1745 +SCHEMAP_UNKNOWN_PROCESSCONTENT_CHILD=1746 +SCHEMAP_UNKNOWN_REF=1747 +SCHEMAP_UNKNOWN_RESTRICTION_CHILD=1748 +SCHEMAP_UNKNOWN_SCHEMAS_CHILD=1749 +SCHEMAP_UNKNOWN_SEQUENCE_CHILD=1750 +SCHEMAP_UNKNOWN_SIMPLECONTENT_CHILD=1751 +SCHEMAP_UNKNOWN_SIMPLETYPE_CHILD=1752 +SCHEMAP_UNKNOWN_TYPE=1753 +SCHEMAP_UNKNOWN_UNION_CHILD=1754 +SCHEMAP_ELEM_DEFAULT_FIXED=1755 +SCHEMAP_REGEXP_INVALID=1756 +SCHEMAP_FAILED_LOAD=1757 +SCHEMAP_NOTHING_TO_PARSE=1758 +SCHEMAP_NOROOT=1759 +SCHEMAP_REDEFINED_GROUP=1760 +SCHEMAP_REDEFINED_TYPE=1761 +SCHEMAP_REDEFINED_ELEMENT=1762 +SCHEMAP_REDEFINED_ATTRGROUP=1763 +SCHEMAP_REDEFINED_ATTR=1764 +SCHEMAP_REDEFINED_NOTATION=1765 +SCHEMAP_FAILED_PARSE=1766 +SCHEMAP_UNKNOWN_PREFIX=1767 +SCHEMAP_DEF_AND_PREFIX=1768 +SCHEMAP_UNKNOWN_INCLUDE_CHILD=1769 +SCHEMAP_INCLUDE_SCHEMA_NOT_URI=1770 +SCHEMAP_INCLUDE_SCHEMA_NO_URI=1771 +SCHEMAP_NOT_SCHEMA=1772 +SCHEMAP_UNKNOWN_MEMBER_TYPE=1773 +SCHEMAP_INVALID_ATTR_USE=1774 +SCHEMAP_RECURSIVE=1775 +SCHEMAP_SUPERNUMEROUS_LIST_ITEM_TYPE=1776 +SCHEMAP_INVALID_ATTR_COMBINATION=1777 +SCHEMAP_INVALID_ATTR_INLINE_COMBINATION=1778 +SCHEMAP_MISSING_SIMPLETYPE_CHILD=1779 +SCHEMAP_INVALID_ATTR_NAME=1780 +SCHEMAP_REF_AND_CONTENT=1781 +SCHEMAP_CT_PROPS_CORRECT_1=1782 +SCHEMAP_CT_PROPS_CORRECT_2=1783 +SCHEMAP_CT_PROPS_CORRECT_3=1784 +SCHEMAP_CT_PROPS_CORRECT_4=1785 +SCHEMAP_CT_PROPS_CORRECT_5=1786 +SCHEMAP_DERIVATION_OK_RESTRICTION_1=1787 +SCHEMAP_DERIVATION_OK_RESTRICTION_2_1_1=1788 +SCHEMAP_DERIVATION_OK_RESTRICTION_2_1_2=1789 +SCHEMAP_DERIVATION_OK_RESTRICTION_2_2=1790 +SCHEMAP_DERIVATION_OK_RESTRICTION_3=1791 +SCHEMAP_WILDCARD_INVALID_NS_MEMBER=1792 +SCHEMAP_INTERSECTION_NOT_EXPRESSIBLE=1793 +SCHEMAP_UNION_NOT_EXPRESSIBLE=1794 +SCHEMAP_SRC_IMPORT_3_1=1795 +SCHEMAP_SRC_IMPORT_3_2=1796 +SCHEMAP_DERIVATION_OK_RESTRICTION_4_1=1797 +SCHEMAP_DERIVATION_OK_RESTRICTION_4_2=1798 +SCHEMAP_DERIVATION_OK_RESTRICTION_4_3=1799 +SCHEMAP_COS_CT_EXTENDS_1_3=1800 +SCHEMAV_NOROOT=1801 +SCHEMAV_UNDECLAREDELEM=1802 +SCHEMAV_NOTTOPLEVEL=1803 +SCHEMAV_MISSING=1804 +SCHEMAV_WRONGELEM=1805 +SCHEMAV_NOTYPE=1806 +SCHEMAV_NOROLLBACK=1807 +SCHEMAV_ISABSTRACT=1808 +SCHEMAV_NOTEMPTY=1809 +SCHEMAV_ELEMCONT=1810 +SCHEMAV_HAVEDEFAULT=1811 +SCHEMAV_NOTNILLABLE=1812 +SCHEMAV_EXTRACONTENT=1813 +SCHEMAV_INVALIDATTR=1814 +SCHEMAV_INVALIDELEM=1815 +SCHEMAV_NOTDETERMINIST=1816 +SCHEMAV_CONSTRUCT=1817 +SCHEMAV_INTERNAL=1818 +SCHEMAV_NOTSIMPLE=1819 +SCHEMAV_ATTRUNKNOWN=1820 +SCHEMAV_ATTRINVALID=1821 +SCHEMAV_VALUE=1822 +SCHEMAV_FACET=1823 +SCHEMAV_CVC_DATATYPE_VALID_1_2_1=1824 +SCHEMAV_CVC_DATATYPE_VALID_1_2_2=1825 +SCHEMAV_CVC_DATATYPE_VALID_1_2_3=1826 +SCHEMAV_CVC_TYPE_3_1_1=1827 +SCHEMAV_CVC_TYPE_3_1_2=1828 +SCHEMAV_CVC_FACET_VALID=1829 +SCHEMAV_CVC_LENGTH_VALID=1830 +SCHEMAV_CVC_MINLENGTH_VALID=1831 +SCHEMAV_CVC_MAXLENGTH_VALID=1832 +SCHEMAV_CVC_MININCLUSIVE_VALID=1833 +SCHEMAV_CVC_MAXINCLUSIVE_VALID=1834 +SCHEMAV_CVC_MINEXCLUSIVE_VALID=1835 +SCHEMAV_CVC_MAXEXCLUSIVE_VALID=1836 +SCHEMAV_CVC_TOTALDIGITS_VALID=1837 +SCHEMAV_CVC_FRACTIONDIGITS_VALID=1838 +SCHEMAV_CVC_PATTERN_VALID=1839 +SCHEMAV_CVC_ENUMERATION_VALID=1840 +SCHEMAV_CVC_COMPLEX_TYPE_2_1=1841 +SCHEMAV_CVC_COMPLEX_TYPE_2_2=1842 +SCHEMAV_CVC_COMPLEX_TYPE_2_3=1843 +SCHEMAV_CVC_COMPLEX_TYPE_2_4=1844 +SCHEMAV_CVC_ELT_1=1845 +SCHEMAV_CVC_ELT_2=1846 +SCHEMAV_CVC_ELT_3_1=1847 +SCHEMAV_CVC_ELT_3_2_1=1848 +SCHEMAV_CVC_ELT_3_2_2=1849 +SCHEMAV_CVC_ELT_4_1=1850 +SCHEMAV_CVC_ELT_4_2=1851 +SCHEMAV_CVC_ELT_4_3=1852 +SCHEMAV_CVC_ELT_5_1_1=1853 +SCHEMAV_CVC_ELT_5_1_2=1854 +SCHEMAV_CVC_ELT_5_2_1=1855 +SCHEMAV_CVC_ELT_5_2_2_1=1856 +SCHEMAV_CVC_ELT_5_2_2_2_1=1857 +SCHEMAV_CVC_ELT_5_2_2_2_2=1858 +SCHEMAV_CVC_ELT_6=1859 +SCHEMAV_CVC_ELT_7=1860 +SCHEMAV_CVC_ATTRIBUTE_1=1861 +SCHEMAV_CVC_ATTRIBUTE_2=1862 +SCHEMAV_CVC_ATTRIBUTE_3=1863 +SCHEMAV_CVC_ATTRIBUTE_4=1864 +SCHEMAV_CVC_COMPLEX_TYPE_3_1=1865 +SCHEMAV_CVC_COMPLEX_TYPE_3_2_1=1866 +SCHEMAV_CVC_COMPLEX_TYPE_3_2_2=1867 +SCHEMAV_CVC_COMPLEX_TYPE_4=1868 +SCHEMAV_CVC_COMPLEX_TYPE_5_1=1869 +SCHEMAV_CVC_COMPLEX_TYPE_5_2=1870 +SCHEMAV_ELEMENT_CONTENT=1871 +SCHEMAV_DOCUMENT_ELEMENT_MISSING=1872 +SCHEMAV_CVC_COMPLEX_TYPE_1=1873 +SCHEMAV_CVC_AU=1874 +SCHEMAV_CVC_TYPE_1=1875 +SCHEMAV_CVC_TYPE_2=1876 +SCHEMAV_CVC_IDC=1877 +SCHEMAV_CVC_WILDCARD=1878 +SCHEMAV_MISC=1879 +XPTR_UNKNOWN_SCHEME=1900 +XPTR_CHILDSEQ_START=1901 +XPTR_EVAL_FAILED=1902 +XPTR_EXTRA_OBJECTS=1903 +C14N_CREATE_CTXT=1950 +C14N_REQUIRES_UTF8=1951 +C14N_CREATE_STACK=1952 +C14N_INVALID_NODE=1953 +C14N_UNKNOW_NODE=1954 +C14N_RELATIVE_NAMESPACE=1955 +FTP_PASV_ANSWER=2000 +FTP_EPSV_ANSWER=2001 +FTP_ACCNT=2002 +FTP_URL_SYNTAX=2003 +HTTP_URL_SYNTAX=2020 +HTTP_USE_IP=2021 +HTTP_UNKNOWN_HOST=2022 +SCHEMAP_SRC_SIMPLE_TYPE_1=3000 +SCHEMAP_SRC_SIMPLE_TYPE_2=3001 +SCHEMAP_SRC_SIMPLE_TYPE_3=3002 +SCHEMAP_SRC_SIMPLE_TYPE_4=3003 +SCHEMAP_SRC_RESOLVE=3004 +SCHEMAP_SRC_RESTRICTION_BASE_OR_SIMPLETYPE=3005 +SCHEMAP_SRC_LIST_ITEMTYPE_OR_SIMPLETYPE=3006 +SCHEMAP_SRC_UNION_MEMBERTYPES_OR_SIMPLETYPES=3007 +SCHEMAP_ST_PROPS_CORRECT_1=3008 +SCHEMAP_ST_PROPS_CORRECT_2=3009 +SCHEMAP_ST_PROPS_CORRECT_3=3010 +SCHEMAP_COS_ST_RESTRICTS_1_1=3011 +SCHEMAP_COS_ST_RESTRICTS_1_2=3012 +SCHEMAP_COS_ST_RESTRICTS_1_3_1=3013 +SCHEMAP_COS_ST_RESTRICTS_1_3_2=3014 +SCHEMAP_COS_ST_RESTRICTS_2_1=3015 +SCHEMAP_COS_ST_RESTRICTS_2_3_1_1=3016 +SCHEMAP_COS_ST_RESTRICTS_2_3_1_2=3017 +SCHEMAP_COS_ST_RESTRICTS_2_3_2_1=3018 +SCHEMAP_COS_ST_RESTRICTS_2_3_2_2=3019 +SCHEMAP_COS_ST_RESTRICTS_2_3_2_3=3020 +SCHEMAP_COS_ST_RESTRICTS_2_3_2_4=3021 +SCHEMAP_COS_ST_RESTRICTS_2_3_2_5=3022 +SCHEMAP_COS_ST_RESTRICTS_3_1=3023 +SCHEMAP_COS_ST_RESTRICTS_3_3_1=3024 +SCHEMAP_COS_ST_RESTRICTS_3_3_1_2=3025 +SCHEMAP_COS_ST_RESTRICTS_3_3_2_2=3026 +SCHEMAP_COS_ST_RESTRICTS_3_3_2_1=3027 +SCHEMAP_COS_ST_RESTRICTS_3_3_2_3=3028 +SCHEMAP_COS_ST_RESTRICTS_3_3_2_4=3029 +SCHEMAP_COS_ST_RESTRICTS_3_3_2_5=3030 +SCHEMAP_COS_ST_DERIVED_OK_2_1=3031 +SCHEMAP_COS_ST_DERIVED_OK_2_2=3032 +SCHEMAP_S4S_ELEM_NOT_ALLOWED=3033 +SCHEMAP_S4S_ELEM_MISSING=3034 +SCHEMAP_S4S_ATTR_NOT_ALLOWED=3035 +SCHEMAP_S4S_ATTR_MISSING=3036 +SCHEMAP_S4S_ATTR_INVALID_VALUE=3037 +SCHEMAP_SRC_ELEMENT_1=3038 +SCHEMAP_SRC_ELEMENT_2_1=3039 +SCHEMAP_SRC_ELEMENT_2_2=3040 +SCHEMAP_SRC_ELEMENT_3=3041 +SCHEMAP_P_PROPS_CORRECT_1=3042 +SCHEMAP_P_PROPS_CORRECT_2_1=3043 +SCHEMAP_P_PROPS_CORRECT_2_2=3044 +SCHEMAP_E_PROPS_CORRECT_2=3045 +SCHEMAP_E_PROPS_CORRECT_3=3046 +SCHEMAP_E_PROPS_CORRECT_4=3047 +SCHEMAP_E_PROPS_CORRECT_5=3048 +SCHEMAP_E_PROPS_CORRECT_6=3049 +SCHEMAP_SRC_INCLUDE=3050 +SCHEMAP_SRC_ATTRIBUTE_1=3051 +SCHEMAP_SRC_ATTRIBUTE_2=3052 +SCHEMAP_SRC_ATTRIBUTE_3_1=3053 +SCHEMAP_SRC_ATTRIBUTE_3_2=3054 +SCHEMAP_SRC_ATTRIBUTE_4=3055 +SCHEMAP_NO_XMLNS=3056 +SCHEMAP_NO_XSI=3057 +SCHEMAP_COS_VALID_DEFAULT_1=3058 +SCHEMAP_COS_VALID_DEFAULT_2_1=3059 +SCHEMAP_COS_VALID_DEFAULT_2_2_1=3060 +SCHEMAP_COS_VALID_DEFAULT_2_2_2=3061 +SCHEMAP_CVC_SIMPLE_TYPE=3062 +SCHEMAP_COS_CT_EXTENDS_1_1=3063 +SCHEMAP_SRC_IMPORT_1_1=3064 +SCHEMAP_SRC_IMPORT_1_2=3065 +SCHEMAP_SRC_IMPORT_2=3066 +SCHEMAP_SRC_IMPORT_2_1=3067 +SCHEMAP_SRC_IMPORT_2_2=3068 +SCHEMAP_INTERNAL=3069 +SCHEMAP_NOT_DETERMINISTIC=3070 +SCHEMAP_SRC_ATTRIBUTE_GROUP_1=3071 +SCHEMAP_SRC_ATTRIBUTE_GROUP_2=3072 +SCHEMAP_SRC_ATTRIBUTE_GROUP_3=3073 +SCHEMAP_MG_PROPS_CORRECT_1=3074 +SCHEMAP_MG_PROPS_CORRECT_2=3075 +SCHEMAP_SRC_CT_1=3076 +SCHEMAP_DERIVATION_OK_RESTRICTION_2_1_3=3077 +SCHEMAP_AU_PROPS_CORRECT_2=3078 +SCHEMAP_A_PROPS_CORRECT_2=3079 +SCHEMAP_C_PROPS_CORRECT=3080 +SCHEMAP_SRC_REDEFINE=3081 +SCHEMAP_SRC_IMPORT=3082 +SCHEMAP_WARN_SKIP_SCHEMA=3083 +SCHEMAP_WARN_UNLOCATED_SCHEMA=3084 +SCHEMAP_WARN_ATTR_REDECL_PROH=3085 +SCHEMAP_WARN_ATTR_POINTLESS_PROH=3086 +SCHEMAP_AG_PROPS_CORRECT=3087 +SCHEMAP_COS_CT_EXTENDS_1_2=3088 +SCHEMAP_AU_PROPS_CORRECT=3089 +SCHEMAP_A_PROPS_CORRECT_3=3090 +SCHEMAP_COS_ALL_LIMITED=3091 +SCHEMATRONV_ASSERT=4000 +SCHEMATRONV_REPORT=4001 +MODULE_OPEN=4900 +MODULE_CLOSE=4901 +CHECK_FOUND_ELEMENT=5000 +CHECK_FOUND_ATTRIBUTE=5001 +CHECK_FOUND_TEXT=5002 +CHECK_FOUND_CDATA=5003 +CHECK_FOUND_ENTITYREF=5004 +CHECK_FOUND_ENTITY=5005 +CHECK_FOUND_PI=5006 +CHECK_FOUND_COMMENT=5007 +CHECK_FOUND_DOCTYPE=5008 +CHECK_FOUND_FRAGMENT=5009 +CHECK_FOUND_NOTATION=5010 +CHECK_UNKNOWN_NODE=5011 +CHECK_ENTITY_TYPE=5012 +CHECK_NO_PARENT=5013 +CHECK_NO_DOC=5014 +CHECK_NO_NAME=5015 +CHECK_NO_ELEM=5016 +CHECK_WRONG_DOC=5017 +CHECK_NO_PREV=5018 +CHECK_WRONG_PREV=5019 +CHECK_NO_NEXT=5020 +CHECK_WRONG_NEXT=5021 +CHECK_NOT_DTD=5022 +CHECK_NOT_ATTR=5023 +CHECK_NOT_ATTR_DECL=5024 +CHECK_NOT_ELEM_DECL=5025 +CHECK_NOT_ENTITY_DECL=5026 +CHECK_NOT_NS_DECL=5027 +CHECK_NO_HREF=5028 +CHECK_WRONG_PARENT=5029 +CHECK_NS_SCOPE=5030 +CHECK_NS_ANCESTOR=5031 +CHECK_NOT_UTF8=5032 +CHECK_NO_DICT=5033 +CHECK_NOT_NCNAME=5034 +CHECK_OUTSIDE_DICT=5035 +CHECK_WRONG_NAME=5036 +CHECK_NAME_NOT_NULL=5037 +I18N_NO_NAME=6000 +I18N_NO_HANDLER=6001 +I18N_EXCESS_HANDLER=6002 +I18N_CONV_FAILED=6003 +I18N_NO_OUTPUT=6004 +BUF_OVERFLOW=7000 +""" + +cdef object __RELAXNG_ERROR_TYPES = """\ +RELAXNG_OK=0 +RELAXNG_ERR_MEMORY=1 +RELAXNG_ERR_TYPE=2 +RELAXNG_ERR_TYPEVAL=3 +RELAXNG_ERR_DUPID=4 +RELAXNG_ERR_TYPECMP=5 +RELAXNG_ERR_NOSTATE=6 +RELAXNG_ERR_NODEFINE=7 +RELAXNG_ERR_LISTEXTRA=8 +RELAXNG_ERR_LISTEMPTY=9 +RELAXNG_ERR_INTERNODATA=10 +RELAXNG_ERR_INTERSEQ=11 +RELAXNG_ERR_INTEREXTRA=12 +RELAXNG_ERR_ELEMNAME=13 +RELAXNG_ERR_ATTRNAME=14 +RELAXNG_ERR_ELEMNONS=15 +RELAXNG_ERR_ATTRNONS=16 +RELAXNG_ERR_ELEMWRONGNS=17 +RELAXNG_ERR_ATTRWRONGNS=18 +RELAXNG_ERR_ELEMEXTRANS=19 +RELAXNG_ERR_ATTREXTRANS=20 +RELAXNG_ERR_ELEMNOTEMPTY=21 +RELAXNG_ERR_NOELEM=22 +RELAXNG_ERR_NOTELEM=23 +RELAXNG_ERR_ATTRVALID=24 +RELAXNG_ERR_CONTENTVALID=25 +RELAXNG_ERR_EXTRACONTENT=26 +RELAXNG_ERR_INVALIDATTR=27 +RELAXNG_ERR_DATAELEM=28 +RELAXNG_ERR_VALELEM=29 +RELAXNG_ERR_LISTELEM=30 +RELAXNG_ERR_DATATYPE=31 +RELAXNG_ERR_VALUE=32 +RELAXNG_ERR_LIST=33 +RELAXNG_ERR_NOGRAMMAR=34 +RELAXNG_ERR_EXTRADATA=35 +RELAXNG_ERR_LACKDATA=36 +RELAXNG_ERR_INTERNAL=37 +RELAXNG_ERR_ELEMWRONG=38 +RELAXNG_ERR_TEXTWRONG=39 +""" +# --- END: GENERATED CONSTANTS --- + +__initErrorConstants() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/xpath.pxi b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/xpath.pxi new file mode 100644 index 0000000000000000000000000000000000000000..352f63134734780e5a9c869ccb59b4cb4e4ade40 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/xpath.pxi @@ -0,0 +1,487 @@ +# XPath evaluation + +class XPathSyntaxError(LxmlSyntaxError, XPathError): + pass + +################################################################################ +# XPath + +cdef object _XPATH_SYNTAX_ERRORS = ( + xmlerror.XML_XPATH_NUMBER_ERROR, + xmlerror.XML_XPATH_UNFINISHED_LITERAL_ERROR, + xmlerror.XML_XPATH_VARIABLE_REF_ERROR, + xmlerror.XML_XPATH_INVALID_PREDICATE_ERROR, + xmlerror.XML_XPATH_UNCLOSED_ERROR, + xmlerror.XML_XPATH_INVALID_CHAR_ERROR +) + +cdef object _XPATH_EVAL_ERRORS = ( + xmlerror.XML_XPATH_UNDEF_VARIABLE_ERROR, + xmlerror.XML_XPATH_UNDEF_PREFIX_ERROR, + xmlerror.XML_XPATH_UNKNOWN_FUNC_ERROR, + xmlerror.XML_XPATH_INVALID_OPERAND, + xmlerror.XML_XPATH_INVALID_TYPE, + xmlerror.XML_XPATH_INVALID_ARITY, + xmlerror.XML_XPATH_INVALID_CTXT_SIZE, + xmlerror.XML_XPATH_INVALID_CTXT_POSITION +) + +cdef int _register_xpath_function(void* ctxt, name_utf, ns_utf) noexcept: + if ns_utf is None: + return xpath.xmlXPathRegisterFunc( + ctxt, _xcstr(name_utf), + _xpath_function_call) + else: + return xpath.xmlXPathRegisterFuncNS( + ctxt, _xcstr(name_utf), _xcstr(ns_utf), + _xpath_function_call) + +cdef int _unregister_xpath_function(void* ctxt, name_utf, ns_utf) noexcept: + if ns_utf is None: + return xpath.xmlXPathRegisterFunc( + ctxt, _xcstr(name_utf), NULL) + else: + return xpath.xmlXPathRegisterFuncNS( + ctxt, _xcstr(name_utf), _xcstr(ns_utf), NULL) + + +@cython.final +@cython.internal +cdef class _XPathContext(_BaseContext): + cdef object _variables + def __init__(self, namespaces, extensions, error_log, enable_regexp, variables, + build_smart_strings): + self._variables = variables + _BaseContext.__init__(self, namespaces, extensions, error_log, enable_regexp, + build_smart_strings) + + cdef set_context(self, xpath.xmlXPathContext* xpathCtxt): + self._set_xpath_context(xpathCtxt) + # This would be a good place to set up the XPath parser dict, but + # we cannot use the current thread dict as we do not know which + # thread will execute the XPath evaluator - so, no dict for now. + self.registerLocalNamespaces() + self.registerLocalFunctions(xpathCtxt, _register_xpath_function) + + cdef register_context(self, _Document doc): + self._register_context(doc) + self.registerGlobalNamespaces() + self.registerGlobalFunctions(self._xpathCtxt, _register_xpath_function) + self.registerExsltFunctions() + if self._variables is not None: + self.registerVariables(self._variables) + + cdef unregister_context(self): + self.unregisterGlobalFunctions( + self._xpathCtxt, _unregister_xpath_function) + self.unregisterGlobalNamespaces() + xpath.xmlXPathRegisteredVariablesCleanup(self._xpathCtxt) + self._cleanup_context() + + cdef void registerExsltFunctions(self) noexcept: + if xslt.LIBXSLT_VERSION < 10125: + # we'd only execute dummy functions anyway + return + tree.xmlHashScan( + self._xpathCtxt.nsHash, _registerExsltFunctionsForNamespaces, + self._xpathCtxt) + + cdef registerVariables(self, variable_dict): + for name, value in variable_dict.items(): + name_utf = self._to_utf(name) + xpath.xmlXPathRegisterVariable( + self._xpathCtxt, _xcstr(name_utf), _wrapXPathObject(value, None, None)) + + cdef registerVariable(self, name, value): + name_utf = self._to_utf(name) + xpath.xmlXPathRegisterVariable( + self._xpathCtxt, _xcstr(name_utf), _wrapXPathObject(value, None, None)) + + +cdef void _registerExsltFunctionsForNamespaces( + void* _c_href, void* _ctxt, const_xmlChar* c_prefix) noexcept: + c_href = _c_href + ctxt = _ctxt + + if tree.xmlStrcmp(c_href, xslt.EXSLT_DATE_NAMESPACE) == 0: + xslt.exsltDateXpathCtxtRegister(ctxt, c_prefix) + elif tree.xmlStrcmp(c_href, xslt.EXSLT_SETS_NAMESPACE) == 0: + xslt.exsltSetsXpathCtxtRegister(ctxt, c_prefix) + elif tree.xmlStrcmp(c_href, xslt.EXSLT_MATH_NAMESPACE) == 0: + xslt.exsltMathXpathCtxtRegister(ctxt, c_prefix) + elif tree.xmlStrcmp(c_href, xslt.EXSLT_STRINGS_NAMESPACE) == 0: + xslt.exsltStrXpathCtxtRegister(ctxt, c_prefix) + + +cdef class _XPathEvaluatorBase: + cdef xpath.xmlXPathContext* _xpathCtxt + cdef _XPathContext _context + cdef python.PyThread_type_lock _eval_lock + cdef _ErrorLog _error_log + def __cinit__(self): + self._xpathCtxt = NULL + if config.ENABLE_THREADING: + self._eval_lock = python.PyThread_allocate_lock() + if self._eval_lock is NULL: + raise MemoryError() + self._error_log = _ErrorLog() + + def __init__(self, namespaces, extensions, enable_regexp, + smart_strings): + self._context = _XPathContext(namespaces, extensions, self._error_log, + enable_regexp, None, smart_strings) + + @property + def error_log(self): + assert self._error_log is not None, "XPath evaluator not initialised" + return self._error_log.copy() + + def __dealloc__(self): + if self._xpathCtxt is not NULL: + xpath.xmlXPathFreeContext(self._xpathCtxt) + if config.ENABLE_THREADING: + if self._eval_lock is not NULL: + python.PyThread_free_lock(self._eval_lock) + + cdef set_context(self, xpath.xmlXPathContext* xpathCtxt): + self._xpathCtxt = xpathCtxt + self._context.set_context(xpathCtxt) + + cdef bint _checkAbsolutePath(self, char* path) noexcept: + cdef char c + if path is NULL: + return 0 + c = path[0] + while c == c' ' or c == c'\t': + path = path + 1 + c = path[0] + return c == c'/' + + @cython.final + cdef int _lock(self) except -1: + cdef int result + if config.ENABLE_THREADING and self._eval_lock != NULL: + with nogil: + result = python.PyThread_acquire_lock( + self._eval_lock, python.WAIT_LOCK) + if result == 0: + raise XPathError, "XPath evaluator locking failed" + return 0 + + @cython.final + cdef void _unlock(self) noexcept: + if config.ENABLE_THREADING and self._eval_lock != NULL: + python.PyThread_release_lock(self._eval_lock) + + cdef _build_parse_error(self): + cdef _BaseErrorLog entries + entries = self._error_log.filter_types(_XPATH_SYNTAX_ERRORS) + if entries: + message = entries._buildExceptionMessage(None) + if message is not None: + return XPathSyntaxError(message, self._error_log) + return XPathSyntaxError( + self._error_log._buildExceptionMessage("Error in xpath expression"), + self._error_log) + + cdef _build_eval_error(self): + cdef _BaseErrorLog entries + entries = self._error_log.filter_types(_XPATH_EVAL_ERRORS) + if not entries: + entries = self._error_log.filter_types(_XPATH_SYNTAX_ERRORS) + if entries: + message = entries._buildExceptionMessage(None) + if message is not None: + return XPathEvalError(message, self._error_log) + return XPathEvalError( + self._error_log._buildExceptionMessage("Error in xpath expression"), + self._error_log) + + cdef object _handle_result(self, xpath.xmlXPathObject* xpathObj, _Document doc): + if self._context._exc._has_raised(): + if xpathObj is not NULL: + _freeXPathObject(xpathObj) + xpathObj = NULL + self._context._release_temp_refs() + self._context._exc._raise_if_stored() + + if xpathObj is NULL: + self._context._release_temp_refs() + raise self._build_eval_error() + + try: + result = _unwrapXPathObject(xpathObj, doc, self._context) + finally: + _freeXPathObject(xpathObj) + self._context._release_temp_refs() + + return result + + +cdef class XPathElementEvaluator(_XPathEvaluatorBase): + """XPathElementEvaluator(self, element, namespaces=None, extensions=None, regexp=True, smart_strings=True) + Create an XPath evaluator for an element. + + Absolute XPath expressions (starting with '/') will be evaluated against + the ElementTree as returned by getroottree(). + + Additional namespace declarations can be passed with the + 'namespace' keyword argument. EXSLT regular expression support + can be disabled with the 'regexp' boolean keyword (defaults to + True). Smart strings will be returned for string results unless + you pass ``smart_strings=False``. + """ + cdef _Element _element + def __init__(self, _Element element not None, *, namespaces=None, + extensions=None, regexp=True, smart_strings=True): + cdef xpath.xmlXPathContext* xpathCtxt + cdef int ns_register_status + cdef _Document doc + _assertValidNode(element) + _assertValidDoc(element._doc) + self._element = element + doc = element._doc + _XPathEvaluatorBase.__init__(self, namespaces, extensions, + regexp, smart_strings) + xpathCtxt = xpath.xmlXPathNewContext(doc._c_doc) + if xpathCtxt is NULL: + raise MemoryError() + self.set_context(xpathCtxt) + + def register_namespace(self, prefix, uri): + """Register a namespace with the XPath context. + """ + assert self._xpathCtxt is not NULL, "XPath context not initialised" + self._context.addNamespace(prefix, uri) + + def register_namespaces(self, namespaces): + """Register a prefix -> uri dict. + """ + assert self._xpathCtxt is not NULL, "XPath context not initialised" + for prefix, uri in namespaces.items(): + self._context.addNamespace(prefix, uri) + + def __call__(self, _path, **_variables): + """__call__(self, _path, **_variables) + + Evaluate an XPath expression on the document. + + Variables may be provided as keyword arguments. Note that namespaces + are currently not supported for variables. + + Absolute XPath expressions (starting with '/') will be evaluated + against the ElementTree as returned by getroottree(). + """ + cdef xpath.xmlXPathObject* xpathObj + cdef _Document doc + assert self._xpathCtxt is not NULL, "XPath context not initialised" + path = _utf8(_path) + doc = self._element._doc + + self._lock() + self._xpathCtxt.node = self._element._c_node + try: + self._context.register_context(doc) + self._context.registerVariables(_variables) + c_path = _xcstr(path) + with nogil: + xpathObj = xpath.xmlXPathEvalExpression( + c_path, self._xpathCtxt) + result = self._handle_result(xpathObj, doc) + finally: + self._context.unregister_context() + self._unlock() + + return result + + +cdef class XPathDocumentEvaluator(XPathElementEvaluator): + """XPathDocumentEvaluator(self, etree, namespaces=None, extensions=None, regexp=True, smart_strings=True) + Create an XPath evaluator for an ElementTree. + + Additional namespace declarations can be passed with the + 'namespace' keyword argument. EXSLT regular expression support + can be disabled with the 'regexp' boolean keyword (defaults to + True). Smart strings will be returned for string results unless + you pass ``smart_strings=False``. + """ + def __init__(self, _ElementTree etree not None, *, namespaces=None, + extensions=None, regexp=True, smart_strings=True): + XPathElementEvaluator.__init__( + self, etree._context_node, namespaces=namespaces, + extensions=extensions, regexp=regexp, + smart_strings=smart_strings) + + def __call__(self, _path, **_variables): + """__call__(self, _path, **_variables) + + Evaluate an XPath expression on the document. + + Variables may be provided as keyword arguments. Note that namespaces + are currently not supported for variables. + """ + cdef xpath.xmlXPathObject* xpathObj + cdef xmlDoc* c_doc + cdef _Document doc + assert self._xpathCtxt is not NULL, "XPath context not initialised" + path = _utf8(_path) + doc = self._element._doc + + self._lock() + try: + self._context.register_context(doc) + c_doc = _fakeRootDoc(doc._c_doc, self._element._c_node) + try: + self._context.registerVariables(_variables) + c_path = _xcstr(path) + with nogil: + self._xpathCtxt.doc = c_doc + self._xpathCtxt.node = tree.xmlDocGetRootElement(c_doc) + xpathObj = xpath.xmlXPathEvalExpression( + c_path, self._xpathCtxt) + result = self._handle_result(xpathObj, doc) + finally: + _destroyFakeDoc(doc._c_doc, c_doc) + self._context.unregister_context() + finally: + self._unlock() + + return result + + +def XPathEvaluator(etree_or_element, *, namespaces=None, extensions=None, + regexp=True, smart_strings=True): + """XPathEvaluator(etree_or_element, namespaces=None, extensions=None, regexp=True, smart_strings=True) + + Creates an XPath evaluator for an ElementTree or an Element. + + The resulting object can be called with an XPath expression as argument + and XPath variables provided as keyword arguments. + + Additional namespace declarations can be passed with the + 'namespace' keyword argument. EXSLT regular expression support + can be disabled with the 'regexp' boolean keyword (defaults to + True). Smart strings will be returned for string results unless + you pass ``smart_strings=False``. + """ + if isinstance(etree_or_element, _ElementTree): + return XPathDocumentEvaluator( + etree_or_element, namespaces=namespaces, + extensions=extensions, regexp=regexp, smart_strings=smart_strings) + else: + return XPathElementEvaluator( + etree_or_element, namespaces=namespaces, + extensions=extensions, regexp=regexp, smart_strings=smart_strings) + + +cdef class XPath(_XPathEvaluatorBase): + """XPath(self, path, namespaces=None, extensions=None, regexp=True, smart_strings=True) + A compiled XPath expression that can be called on Elements and ElementTrees. + + Besides the XPath expression, you can pass prefix-namespace + mappings and extension functions to the constructor through the + keyword arguments ``namespaces`` and ``extensions``. EXSLT + regular expression support can be disabled with the 'regexp' + boolean keyword (defaults to True). Smart strings will be + returned for string results unless you pass + ``smart_strings=False``. + """ + cdef xpath.xmlXPathCompExpr* _xpath + cdef bytes _path + def __cinit__(self): + self._xpath = NULL + + def __init__(self, path, *, namespaces=None, extensions=None, + regexp=True, smart_strings=True): + cdef xpath.xmlXPathContext* xpathCtxt + _XPathEvaluatorBase.__init__(self, namespaces, extensions, + regexp, smart_strings) + self._path = _utf8(path) + xpathCtxt = xpath.xmlXPathNewContext(NULL) + if xpathCtxt is NULL: + raise MemoryError() + self.set_context(xpathCtxt) + self._xpath = xpath.xmlXPathCtxtCompile(xpathCtxt, _xcstr(self._path)) + if self._xpath is NULL: + raise self._build_parse_error() + + def __call__(self, _etree_or_element, **_variables): + "__call__(self, _etree_or_element, **_variables)" + cdef xpath.xmlXPathObject* xpathObj + cdef _Document document + cdef _Element element + + assert self._xpathCtxt is not NULL, "XPath context not initialised" + document = _documentOrRaise(_etree_or_element) + element = _rootNodeOrRaise(_etree_or_element) + + self._lock() + self._xpathCtxt.doc = document._c_doc + self._xpathCtxt.node = element._c_node + + try: + self._context.register_context(document) + self._context.registerVariables(_variables) + with nogil: + xpathObj = xpath.xmlXPathCompiledEval( + self._xpath, self._xpathCtxt) + result = self._handle_result(xpathObj, document) + finally: + self._context.unregister_context() + self._unlock() + return result + + @property + def path(self): + """The literal XPath expression. + """ + return self._path.decode('UTF-8') + + def __dealloc__(self): + if self._xpath is not NULL: + xpath.xmlXPathFreeCompExpr(self._xpath) + + def __repr__(self): + return self.path + + +cdef object _replace_strings = re.compile(b'("[^"]*")|(\'[^\']*\')').sub +cdef object _find_namespaces = re.compile(b'({[^}]+})').findall + +cdef class ETXPath(XPath): + """ETXPath(self, path, extensions=None, regexp=True, smart_strings=True) + Special XPath class that supports the ElementTree {uri} notation for namespaces. + + Note that this class does not accept the ``namespace`` keyword + argument. All namespaces must be passed as part of the path + string. Smart strings will be returned for string results unless + you pass ``smart_strings=False``. + """ + def __init__(self, path, *, extensions=None, regexp=True, + smart_strings=True): + path, namespaces = self._nsextract_path(path) + XPath.__init__(self, path, namespaces=namespaces, + extensions=extensions, regexp=regexp, + smart_strings=smart_strings) + + cdef _nsextract_path(self, path): + # replace {namespaces} by new prefixes + cdef dict namespaces = {} + cdef list namespace_defs = [] + cdef int i + path_utf = _utf8(path) + stripped_path = _replace_strings(b'', path_utf) # remove string literals + i = 1 + for namespace_def in _find_namespaces(stripped_path): + if namespace_def not in namespace_defs: + prefix = python.PyBytes_FromFormat("__xpp%02d", i) + i += 1 + namespace_defs.append(namespace_def) + namespace = namespace_def[1:-1] # remove '{}' + namespace = (namespace).decode('utf8') + namespaces[prefix.decode('utf8')] = namespace + prefix_str = prefix + b':' + # FIXME: this also replaces {namespaces} within strings! + path_utf = path_utf.replace(namespace_def, prefix_str) + path = path_utf.decode('utf8') + return path, namespaces diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/xsltext.pxi b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/xsltext.pxi new file mode 100644 index 0000000000000000000000000000000000000000..21894b9ef5859a455fd2f9f4443e805818b94517 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/lxml/xsltext.pxi @@ -0,0 +1,242 @@ +# XSLT extension elements + +cdef class XSLTExtension: + """Base class of an XSLT extension element. + """ + def execute(self, context, self_node, input_node, output_parent): + """execute(self, context, self_node, input_node, output_parent) + Execute this extension element. + + Subclasses must override this method. They may append + elements to the `output_parent` element here, or set its text + content. To this end, the `input_node` provides read-only + access to the current node in the input document, and the + `self_node` points to the extension element in the stylesheet. + + Note that the `output_parent` parameter may be `None` if there + is no parent element in the current context (e.g. no content + was added to the output tree yet). + """ + pass + + def apply_templates(self, _XSLTContext context not None, node, output_parent=None, + *, elements_only=False, remove_blank_text=False): + """apply_templates(self, context, node, output_parent=None, elements_only=False, remove_blank_text=False) + + Call this method to retrieve the result of applying templates + to an element. + + The return value is a list of elements or text strings that + were generated by the XSLT processor. If you pass + ``elements_only=True``, strings will be discarded from the result + list. The option ``remove_blank_text=True`` will only discard + strings that consist entirely of whitespace (e.g. formatting). + These options do not apply to Elements, only to bare string results. + + If you pass an Element as `output_parent` parameter, the result + will instead be appended to the element (including attributes + etc.) and the return value will be `None`. This is a safe way + to generate content into the output document directly, without + having to take care of special values like text or attributes. + Note that the string discarding options will be ignored in this + case. + """ + cdef xmlNode* c_parent + cdef xmlNode* c_node + cdef xmlNode* c_context_node + assert context._xsltCtxt is not NULL, "XSLT context not initialised" + c_context_node = _roNodeOf(node) + #assert c_context_node.doc is context._xsltContext.node.doc, \ + # "switching input documents during transformation is not currently supported" + + if output_parent is not None: + c_parent = _nonRoNodeOf(output_parent) + else: + c_parent = tree.xmlNewDocNode( + context._xsltCtxt.output, NULL, "fake-parent", NULL) + + c_node = context._xsltCtxt.insert + context._xsltCtxt.insert = c_parent + xslt.xsltProcessOneNode( + context._xsltCtxt, c_context_node, NULL) + context._xsltCtxt.insert = c_node + + if output_parent is not None: + return None + + try: + return self._collectXSLTResultContent( + context, c_parent, elements_only, remove_blank_text) + finally: + # free all intermediate nodes that will not be freed by proxies + tree.xmlFreeNode(c_parent) + + def process_children(self, _XSLTContext context not None, output_parent=None, + *, elements_only=False, remove_blank_text=False): + """process_children(self, context, output_parent=None, elements_only=False, remove_blank_text=False) + + Call this method to process the XSLT content of the extension + element itself. + + The return value is a list of elements or text strings that + were generated by the XSLT processor. If you pass + ``elements_only=True``, strings will be discarded from the result + list. The option ``remove_blank_text=True`` will only discard + strings that consist entirely of whitespace (e.g. formatting). + These options do not apply to Elements, only to bare string results. + + If you pass an Element as `output_parent` parameter, the result + will instead be appended to the element (including attributes + etc.) and the return value will be `None`. This is a safe way + to generate content into the output document directly, without + having to take care of special values like text or attributes. + Note that the string discarding options will be ignored in this + case. + """ + cdef xmlNode* c_parent + cdef xslt.xsltTransformContext* c_ctxt = context._xsltCtxt + cdef xmlNode* c_old_output_parent = c_ctxt.insert + assert context._xsltCtxt is not NULL, "XSLT context not initialised" + + # output_parent node is used for adding results instead of + # elements list used in apply_templates, that's easier and allows to + # use attributes added to extension element with . + + if output_parent is not None: + c_parent = _nonRoNodeOf(output_parent) + else: + c_parent = tree.xmlNewDocNode( + context._xsltCtxt.output, NULL, "fake-parent", NULL) + + c_ctxt.insert = c_parent + xslt.xsltApplyOneTemplate(c_ctxt, + c_ctxt.node, c_ctxt.inst.children, NULL, NULL) + c_ctxt.insert = c_old_output_parent + + if output_parent is not None: + return None + + try: + return self._collectXSLTResultContent( + context, c_parent, elements_only, remove_blank_text) + finally: + # free all intermediate nodes that will not be freed by proxies + tree.xmlFreeNode(c_parent) + + cdef _collectXSLTResultContent(self, _XSLTContext context, xmlNode* c_parent, + bint elements_only, bint remove_blank_text): + cdef xmlNode* c_node + cdef xmlNode* c_next + cdef _ReadOnlyProxy proxy + cdef list results = [] # or maybe _collectAttributes(c_parent, 2) ? + c_node = c_parent.children + while c_node is not NULL: + c_next = c_node.next + if c_node.type == tree.XML_TEXT_NODE: + if not elements_only: + s = funicode(c_node.content) + if not remove_blank_text or s.strip(): + results.append(s) + s = None + elif c_node.type == tree.XML_ELEMENT_NODE: + proxy = _newReadOnlyProxy( + context._extension_element_proxy, c_node) + results.append(proxy) + # unlink node and make sure it will be freed later on + tree.xmlUnlinkNode(c_node) + proxy.free_after_use() + else: + raise TypeError, \ + f"unsupported XSLT result type: {c_node.type}" + c_node = c_next + return results + + +cdef _registerXSLTExtensions(xslt.xsltTransformContext* c_ctxt, + extension_dict): + for ns_utf, name_utf in extension_dict: + xslt.xsltRegisterExtElement( + c_ctxt, _xcstr(name_utf), _xcstr(ns_utf), + _callExtensionElement) + +cdef void _callExtensionElement(xslt.xsltTransformContext* c_ctxt, + xmlNode* c_context_node, + xmlNode* c_inst_node, + void* dummy) noexcept with gil: + cdef _XSLTContext context + cdef XSLTExtension extension + cdef python.PyObject* dict_result + cdef xmlNode* c_node + cdef _ReadOnlyProxy context_node = None, self_node = None + cdef object output_parent # not restricted to ro-nodes + c_uri = _getNs(c_inst_node) + if c_uri is NULL: + # not allowed, and should never happen + return + if c_ctxt.xpathCtxt.userData is NULL: + # just for safety, should never happen + return + context = <_XSLTContext>c_ctxt.xpathCtxt.userData + try: + try: + dict_result = python.PyDict_GetItem( + context._extension_elements, (c_uri, c_inst_node.name)) + if dict_result is NULL: + raise KeyError, f"extension element {funicode(c_inst_node.name)} not found" + extension = dict_result + + try: + # build the context proxy nodes + self_node = _newReadOnlyProxy(None, c_inst_node) + if _isElement(c_ctxt.insert): + output_parent = _newAppendOnlyProxy(self_node, c_ctxt.insert) + else: + # may be the document node or other stuff + output_parent = _newOpaqueAppendOnlyNodeWrapper(c_ctxt.insert) + if c_context_node.type in (tree.XML_DOCUMENT_NODE, + tree.XML_HTML_DOCUMENT_NODE): + c_node = tree.xmlDocGetRootElement(c_context_node) + if c_node is not NULL: + context_node = _newReadOnlyProxy(self_node, c_node) + else: + context_node = None + elif c_context_node.type in (tree.XML_ATTRIBUTE_NODE, + tree.XML_TEXT_NODE, + tree.XML_CDATA_SECTION_NODE): + # this isn't easy to support using read-only + # nodes, as the smart-string factory must + # instantiate the parent proxy somehow... + raise TypeError(f"Unsupported element type: {c_context_node.type}") + else: + context_node = _newReadOnlyProxy(self_node, c_context_node) + + # run the XSLT extension + context._extension_element_proxy = self_node + extension.execute(context, self_node, context_node, output_parent) + finally: + context._extension_element_proxy = None + if self_node is not None: + _freeReadOnlyProxies(self_node) + except Exception as e: + try: + e = unicode(e).encode("UTF-8") + except: + e = repr(e).encode("UTF-8") + message = python.PyBytes_FromFormat( + "Error executing extension element '%s': %s", + c_inst_node.name, _cstr(e)) + xslt.xsltTransformError(c_ctxt, NULL, c_inst_node, "%s", message) + context._exc._store_raised() + except: + # just in case + message = python.PyBytes_FromFormat( + "Error executing extension element '%s'", c_inst_node.name) + xslt.xsltTransformError(c_ctxt, NULL, c_inst_node, "%s", message) + context._exc._store_raised() + except: + # no Python functions here - everything can fail... + xslt.xsltTransformError(c_ctxt, NULL, c_inst_node, + "Error during XSLT extension element evaluation") + context._exc._store_raised() + finally: + return # swallow any further exceptions diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..181ca3b0c6bee682ebce2ba52772f2ce1f9b4d30 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/__init__.py @@ -0,0 +1,975 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# When adding a new object to this init, remember to add it twice: once inside the `_import_structure` dictionary and +# once inside the `if TYPE_CHECKING` branch. The `TYPE_CHECKING` should have import statements as usual, but they are +# only there for type checking. The `_import_structure` is a dictionary submodule to list of object names, and is used +# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names +# in the namespace without actually importing anything (and especially none of the backends). + +__version__ = "4.57.6" + +from pathlib import Path +from typing import TYPE_CHECKING + +# Check the dependencies satisfy the minimal versions required. +from . import dependency_versions_check +from .utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + is_essentia_available, + is_g2p_en_available, + is_librosa_available, + is_mistral_common_available, + is_mlx_available, + is_pretty_midi_available, +) + +# Note: the following symbols are deliberately exported with `as` +# so that mypy, pylint or other static linters can recognize them, +# given that they are not exported using `__all__` in this file. +from .utils import is_bitsandbytes_available as is_bitsandbytes_available +from .utils import is_flax_available as is_flax_available +from .utils import is_keras_nlp_available as is_keras_nlp_available +from .utils import is_scipy_available as is_scipy_available +from .utils import is_sentencepiece_available as is_sentencepiece_available +from .utils import is_speech_available as is_speech_available +from .utils import is_tensorflow_text_available as is_tensorflow_text_available +from .utils import is_tf_available as is_tf_available +from .utils import is_timm_available as is_timm_available +from .utils import is_tokenizers_available as is_tokenizers_available +from .utils import is_torch_available as is_torch_available +from .utils import is_torchaudio_available as is_torchaudio_available +from .utils import is_torchvision_available as is_torchvision_available +from .utils import is_vision_available as is_vision_available +from .utils import logging as logging +from .utils.import_utils import define_import_structure + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +# Base objects, independent of any specific backend +_import_structure = { + "audio_utils": [], + "commands": [], + "configuration_utils": ["PretrainedConfig"], + "convert_graph_to_onnx": [], + "convert_slow_tokenizers_checkpoints_to_fast": [], + "convert_tf_hub_seq_to_seq_bert_to_pytorch": [], + "data": [ + "DataProcessor", + "InputExample", + "InputFeatures", + "SingleSentenceClassificationProcessor", + "SquadExample", + "SquadFeatures", + "SquadV1Processor", + "SquadV2Processor", + "glue_compute_metrics", + "glue_convert_examples_to_features", + "glue_output_modes", + "glue_processors", + "glue_tasks_num_labels", + "squad_convert_examples_to_features", + "xnli_compute_metrics", + "xnli_output_modes", + "xnli_processors", + "xnli_tasks_num_labels", + ], + "data.data_collator": [ + "DataCollator", + "DataCollatorForLanguageModeling", + "DataCollatorForMultipleChoice", + "DataCollatorForPermutationLanguageModeling", + "DataCollatorForSeq2Seq", + "DataCollatorForSOP", + "DataCollatorForTokenClassification", + "DataCollatorForWholeWordMask", + "DataCollatorWithFlattening", + "DataCollatorWithPadding", + "DefaultDataCollator", + "default_data_collator", + ], + "data.metrics": [], + "data.processors": [], + "debug_utils": [], + "dependency_versions_check": [], + "dependency_versions_table": [], + "dynamic_module_utils": [], + "feature_extraction_sequence_utils": ["SequenceFeatureExtractor"], + "feature_extraction_utils": ["BatchFeature", "FeatureExtractionMixin"], + "file_utils": [], + "generation": [ + "AsyncTextIteratorStreamer", + "CompileConfig", + "GenerationConfig", + "TextIteratorStreamer", + "TextStreamer", + "WatermarkingConfig", + ], + "hf_argparser": ["HfArgumentParser"], + "hyperparameter_search": [], + "image_transforms": [], + "integrations": [ + "is_clearml_available", + "is_comet_available", + "is_dvclive_available", + "is_neptune_available", + "is_optuna_available", + "is_ray_available", + "is_ray_tune_available", + "is_sigopt_available", + "is_swanlab_available", + "is_tensorboard_available", + "is_trackio_available", + "is_wandb_available", + ], + "loss": [], + "modelcard": ["ModelCard"], + # Losses + "modeling_tf_pytorch_utils": [ + "convert_tf_weight_name_to_pt_weight_name", + "load_pytorch_checkpoint_in_tf2_model", + "load_pytorch_model_in_tf2_model", + "load_pytorch_weights_in_tf2_model", + "load_tf2_checkpoint_in_pytorch_model", + "load_tf2_model_in_pytorch_model", + "load_tf2_weights_in_pytorch_model", + ], + # Models + "onnx": [], + "pipelines": [ + "AudioClassificationPipeline", + "AutomaticSpeechRecognitionPipeline", + "CsvPipelineDataFormat", + "DepthEstimationPipeline", + "DocumentQuestionAnsweringPipeline", + "FeatureExtractionPipeline", + "FillMaskPipeline", + "ImageClassificationPipeline", + "ImageFeatureExtractionPipeline", + "ImageSegmentationPipeline", + "ImageTextToTextPipeline", + "ImageToImagePipeline", + "ImageToTextPipeline", + "JsonPipelineDataFormat", + "KeypointMatchingPipeline", + "MaskGenerationPipeline", + "NerPipeline", + "ObjectDetectionPipeline", + "PipedPipelineDataFormat", + "Pipeline", + "PipelineDataFormat", + "QuestionAnsweringPipeline", + "SummarizationPipeline", + "TableQuestionAnsweringPipeline", + "Text2TextGenerationPipeline", + "TextClassificationPipeline", + "TextGenerationPipeline", + "TextToAudioPipeline", + "TokenClassificationPipeline", + "TranslationPipeline", + "VideoClassificationPipeline", + "VisualQuestionAnsweringPipeline", + "ZeroShotAudioClassificationPipeline", + "ZeroShotClassificationPipeline", + "ZeroShotImageClassificationPipeline", + "ZeroShotObjectDetectionPipeline", + "pipeline", + ], + "processing_utils": ["ProcessorMixin"], + "quantizers": [], + "testing_utils": [], + "tokenization_utils": ["PreTrainedTokenizer"], + "tokenization_utils_base": [ + "AddedToken", + "BatchEncoding", + "CharSpan", + "PreTrainedTokenizerBase", + "SpecialTokensMixin", + "TokenSpan", + ], + "trainer_callback": [ + "DefaultFlowCallback", + "EarlyStoppingCallback", + "PrinterCallback", + "ProgressCallback", + "TrainerCallback", + "TrainerControl", + "TrainerState", + ], + "trainer_utils": [ + "EvalPrediction", + "IntervalStrategy", + "SchedulerType", + "enable_full_determinism", + "set_seed", + ], + "training_args": ["TrainingArguments"], + "training_args_seq2seq": ["Seq2SeqTrainingArguments"], + "training_args_tf": ["TFTrainingArguments"], + "utils": [ + "CONFIG_NAME", + "MODEL_CARD_NAME", + "PYTORCH_PRETRAINED_BERT_CACHE", + "PYTORCH_TRANSFORMERS_CACHE", + "SPIECE_UNDERLINE", + "TF2_WEIGHTS_NAME", + "TF_WEIGHTS_NAME", + "TRANSFORMERS_CACHE", + "WEIGHTS_NAME", + "TensorType", + "add_end_docstrings", + "add_start_docstrings", + "is_apex_available", + "is_av_available", + "is_bitsandbytes_available", + "is_datasets_available", + "is_faiss_available", + "is_flax_available", + "is_keras_nlp_available", + "is_matplotlib_available", + "is_mlx_available", + "is_phonemizer_available", + "is_psutil_available", + "is_py3nvml_available", + "is_pyctcdecode_available", + "is_sacremoses_available", + "is_safetensors_available", + "is_scipy_available", + "is_sentencepiece_available", + "is_sklearn_available", + "is_speech_available", + "is_tensorflow_text_available", + "is_tf_available", + "is_timm_available", + "is_tokenizers_available", + "is_torch_available", + "is_torch_hpu_available", + "is_torch_mlu_available", + "is_torch_musa_available", + "is_torch_neuroncore_available", + "is_torch_npu_available", + "is_torchvision_available", + "is_torch_xla_available", + "is_torch_xpu_available", + "is_vision_available", + "logging", + ], + "utils.quantization_config": [ + "AqlmConfig", + "AutoRoundConfig", + "AwqConfig", + "BitNetQuantConfig", + "BitsAndBytesConfig", + "CompressedTensorsConfig", + "EetqConfig", + "FbgemmFp8Config", + "FineGrainedFP8Config", + "GPTQConfig", + "HiggsConfig", + "HqqConfig", + "Mxfp4Config", + "QuantoConfig", + "QuarkConfig", + "FPQuantConfig", + "SpQRConfig", + "TorchAoConfig", + "VptqConfig", + ], + "video_utils": [], +} + +# tokenizers-backed objects +try: + if not is_tokenizers_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils import dummy_tokenizers_objects + + _import_structure["utils.dummy_tokenizers_objects"] = [ + name for name in dir(dummy_tokenizers_objects) if not name.startswith("_") + ] +else: + # Fast tokenizers structure + _import_structure["tokenization_utils_fast"] = ["PreTrainedTokenizerFast"] + + +try: + if not (is_sentencepiece_available() and is_tokenizers_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils import dummy_sentencepiece_and_tokenizers_objects + + _import_structure["utils.dummy_sentencepiece_and_tokenizers_objects"] = [ + name for name in dir(dummy_sentencepiece_and_tokenizers_objects) if not name.startswith("_") + ] +else: + _import_structure["convert_slow_tokenizer"] = [ + "SLOW_TO_FAST_CONVERTERS", + "convert_slow_tokenizer", + ] + +try: + if not (is_mistral_common_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils import dummy_mistral_common_objects + + _import_structure["utils.dummy_mistral_common_objects"] = [ + name for name in dir(dummy_mistral_common_objects) if not name.startswith("_") + ] +else: + _import_structure["tokenization_mistral_common"] = ["MistralCommonTokenizer"] + +# Vision-specific objects +try: + if not is_vision_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils import dummy_vision_objects + + _import_structure["utils.dummy_vision_objects"] = [ + name for name in dir(dummy_vision_objects) if not name.startswith("_") + ] +else: + _import_structure["image_processing_base"] = ["ImageProcessingMixin"] + _import_structure["image_processing_utils"] = ["BaseImageProcessor"] + _import_structure["image_utils"] = ["ImageFeatureExtractionMixin"] + +try: + if not is_torchvision_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils import dummy_torchvision_objects + + _import_structure["utils.dummy_torchvision_objects"] = [ + name for name in dir(dummy_torchvision_objects) if not name.startswith("_") + ] +else: + _import_structure["image_processing_utils_fast"] = ["BaseImageProcessorFast"] + _import_structure["video_processing_utils"] = ["BaseVideoProcessor"] + +# PyTorch-backed objects +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils import dummy_pt_objects + + _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")] +else: + _import_structure["model_debugging_utils"] = [ + "model_addition_debugger_context", + ] + _import_structure["activations"] = [] + _import_structure["cache_utils"] = [ + "CacheLayerMixin", + "DynamicLayer", + "StaticLayer", + "StaticSlidingWindowLayer", + "SlidingWindowLayer", + "ChunkedSlidingLayer", + "QuantoQuantizedLayer", + "HQQQuantizedLayer", + "Cache", + "DynamicCache", + "EncoderDecoderCache", + "HQQQuantizedCache", + "HybridCache", + "HybridChunkedCache", + "OffloadedCache", + "OffloadedStaticCache", + "QuantizedCache", + "QuantoQuantizedCache", + "SinkCache", + "SlidingWindowCache", + "StaticCache", + ] + _import_structure["data.datasets"] = [ + "GlueDataset", + "GlueDataTrainingArguments", + "LineByLineTextDataset", + "LineByLineWithRefDataset", + "LineByLineWithSOPTextDataset", + "SquadDataset", + "SquadDataTrainingArguments", + "TextDataset", + "TextDatasetForNextSentencePrediction", + ] + _import_structure["generation"].extend( + [ + "AlternatingCodebooksLogitsProcessor", + "BayesianDetectorConfig", + "BayesianDetectorModel", + "BeamScorer", + "ClassifierFreeGuidanceLogitsProcessor", + "ConstrainedBeamSearchScorer", + "Constraint", + "ConstraintListState", + "DisjunctiveConstraint", + "EncoderNoRepeatNGramLogitsProcessor", + "EncoderRepetitionPenaltyLogitsProcessor", + "EosTokenCriteria", + "EpsilonLogitsWarper", + "EtaLogitsWarper", + "ExponentialDecayLengthPenalty", + "ForcedBOSTokenLogitsProcessor", + "ForcedEOSTokenLogitsProcessor", + "GenerationMixin", + "InfNanRemoveLogitsProcessor", + "LogitNormalization", + "LogitsProcessor", + "LogitsProcessorList", + "MaxLengthCriteria", + "MaxTimeCriteria", + "MinLengthLogitsProcessor", + "MinNewTokensLengthLogitsProcessor", + "MinPLogitsWarper", + "NoBadWordsLogitsProcessor", + "NoRepeatNGramLogitsProcessor", + "PhrasalConstraint", + "PrefixConstrainedLogitsProcessor", + "RepetitionPenaltyLogitsProcessor", + "SequenceBiasLogitsProcessor", + "StoppingCriteria", + "StoppingCriteriaList", + "StopStringCriteria", + "SuppressTokensAtBeginLogitsProcessor", + "SuppressTokensLogitsProcessor", + "SynthIDTextWatermarkDetector", + "SynthIDTextWatermarkingConfig", + "SynthIDTextWatermarkLogitsProcessor", + "TemperatureLogitsWarper", + "TopKLogitsWarper", + "TopPLogitsWarper", + "TypicalLogitsWarper", + "UnbatchedClassifierFreeGuidanceLogitsProcessor", + "WatermarkDetector", + "WatermarkLogitsProcessor", + "WhisperTimeStampLogitsProcessor", + ] + ) + + # PyTorch domain libraries integration + _import_structure["integrations.executorch"] = [ + "TorchExportableModuleWithStaticCache", + "convert_and_export_with_cache", + ] + + _import_structure["modeling_flash_attention_utils"] = [] + _import_structure["modeling_layers"] = ["GradientCheckpointingLayer"] + _import_structure["modeling_outputs"] = [] + _import_structure["modeling_rope_utils"] = ["ROPE_INIT_FUNCTIONS", "dynamic_rope_update"] + _import_structure["modeling_utils"] = ["PreTrainedModel", "AttentionInterface"] + _import_structure["masking_utils"] = ["AttentionMaskInterface"] + _import_structure["optimization"] = [ + "Adafactor", + "get_constant_schedule", + "get_constant_schedule_with_warmup", + "get_cosine_schedule_with_warmup", + "get_cosine_with_hard_restarts_schedule_with_warmup", + "get_cosine_with_min_lr_schedule_with_warmup", + "get_cosine_with_min_lr_schedule_with_warmup_lr_rate", + "get_inverse_sqrt_schedule", + "get_linear_schedule_with_warmup", + "get_polynomial_decay_schedule_with_warmup", + "get_scheduler", + "get_wsd_schedule", + "get_reduce_on_plateau_schedule", + ] + _import_structure["pytorch_utils"] = [ + "Conv1D", + "apply_chunking_to_forward", + "prune_layer", + "infer_device", + ] + _import_structure["sagemaker"] = [] + _import_structure["time_series_utils"] = [] + _import_structure["trainer"] = ["Trainer"] + _import_structure["trainer_pt_utils"] = ["torch_distributed_zero_first"] + _import_structure["trainer_seq2seq"] = ["Seq2SeqTrainer"] + +# TensorFlow-backed objects +try: + if not is_tf_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils import dummy_tf_objects + + _import_structure["utils.dummy_tf_objects"] = [name for name in dir(dummy_tf_objects) if not name.startswith("_")] +else: + _import_structure["activations_tf"] = [] + _import_structure["generation"].extend( + [ + "TFForcedBOSTokenLogitsProcessor", + "TFForcedEOSTokenLogitsProcessor", + "TFForceTokensLogitsProcessor", + "TFGenerationMixin", + "TFLogitsProcessor", + "TFLogitsProcessorList", + "TFLogitsWarper", + "TFMinLengthLogitsProcessor", + "TFNoBadWordsLogitsProcessor", + "TFNoRepeatNGramLogitsProcessor", + "TFRepetitionPenaltyLogitsProcessor", + "TFSuppressTokensAtBeginLogitsProcessor", + "TFSuppressTokensLogitsProcessor", + "TFTemperatureLogitsWarper", + "TFTopKLogitsWarper", + "TFTopPLogitsWarper", + ] + ) + _import_structure["keras_callbacks"] = ["KerasMetricCallback", "PushToHubCallback"] + _import_structure["modeling_tf_outputs"] = [] + _import_structure["modeling_tf_utils"] = [ + "TFPreTrainedModel", + "TFSequenceSummary", + "TFSharedEmbeddings", + "shape_list", + ] + _import_structure["optimization_tf"] = [ + "AdamWeightDecay", + "GradientAccumulator", + "WarmUp", + "create_optimizer", + ] + _import_structure["tf_utils"] = [] + + +# FLAX-backed objects +try: + if not is_flax_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils import dummy_flax_objects + + _import_structure["utils.dummy_flax_objects"] = [ + name for name in dir(dummy_flax_objects) if not name.startswith("_") + ] +else: + _import_structure["generation"].extend( + [ + "FlaxForcedBOSTokenLogitsProcessor", + "FlaxForcedEOSTokenLogitsProcessor", + "FlaxForceTokensLogitsProcessor", + "FlaxGenerationMixin", + "FlaxLogitsProcessor", + "FlaxLogitsProcessorList", + "FlaxLogitsWarper", + "FlaxMinLengthLogitsProcessor", + "FlaxTemperatureLogitsWarper", + "FlaxSuppressTokensAtBeginLogitsProcessor", + "FlaxSuppressTokensLogitsProcessor", + "FlaxTopKLogitsWarper", + "FlaxTopPLogitsWarper", + "FlaxWhisperTimeStampLogitsProcessor", + ] + ) + _import_structure["modeling_flax_outputs"] = [] + _import_structure["modeling_flax_utils"] = ["FlaxPreTrainedModel"] + +# Direct imports for type-checking +if TYPE_CHECKING: + # All modeling imports + from .cache_utils import Cache as Cache + from .cache_utils import ChunkedSlidingLayer as ChunkedSlidingLayer + from .cache_utils import DynamicCache as DynamicCache + from .cache_utils import DynamicLayer as DynamicLayer + from .cache_utils import EncoderDecoderCache as EncoderDecoderCache + from .cache_utils import HQQQuantizedCache as HQQQuantizedCache + from .cache_utils import HQQQuantizedLayer as HQQQuantizedLayer + from .cache_utils import HybridCache as HybridCache + from .cache_utils import OffloadedCache as OffloadedCache + from .cache_utils import OffloadedStaticCache as OffloadedStaticCache + from .cache_utils import QuantizedCache as QuantizedCache + from .cache_utils import QuantoQuantizedCache as QuantoQuantizedCache + from .cache_utils import QuantoQuantizedLayer as QuantoQuantizedLayer + from .cache_utils import SinkCache as SinkCache + from .cache_utils import SlidingWindowCache as SlidingWindowCache + from .cache_utils import SlidingWindowLayer as SlidingWindowLayer + from .cache_utils import StaticCache as StaticCache + from .cache_utils import StaticLayer as StaticLayer + from .cache_utils import StaticSlidingWindowLayer as StaticSlidingWindowLayer + from .configuration_utils import PretrainedConfig as PretrainedConfig + from .convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS as SLOW_TO_FAST_CONVERTERS + from .convert_slow_tokenizer import convert_slow_tokenizer as convert_slow_tokenizer + + # Data + from .data import DataProcessor as DataProcessor + from .data import InputExample as InputExample + from .data import InputFeatures as InputFeatures + from .data import SingleSentenceClassificationProcessor as SingleSentenceClassificationProcessor + from .data import SquadExample as SquadExample + from .data import SquadFeatures as SquadFeatures + from .data import SquadV1Processor as SquadV1Processor + from .data import SquadV2Processor as SquadV2Processor + from .data import glue_compute_metrics as glue_compute_metrics + from .data import glue_convert_examples_to_features as glue_convert_examples_to_features + from .data import glue_output_modes as glue_output_modes + from .data import glue_processors as glue_processors + from .data import glue_tasks_num_labels as glue_tasks_num_labels + from .data import squad_convert_examples_to_features as squad_convert_examples_to_features + from .data import xnli_compute_metrics as xnli_compute_metrics + from .data import xnli_output_modes as xnli_output_modes + from .data import xnli_processors as xnli_processors + from .data import xnli_tasks_num_labels as xnli_tasks_num_labels + from .data.data_collator import DataCollator as DataCollator + from .data.data_collator import DataCollatorForLanguageModeling as DataCollatorForLanguageModeling + from .data.data_collator import DataCollatorForMultipleChoice as DataCollatorForMultipleChoice + from .data.data_collator import ( + DataCollatorForPermutationLanguageModeling as DataCollatorForPermutationLanguageModeling, + ) + from .data.data_collator import DataCollatorForSeq2Seq as DataCollatorForSeq2Seq + from .data.data_collator import DataCollatorForSOP as DataCollatorForSOP + from .data.data_collator import DataCollatorForTokenClassification as DataCollatorForTokenClassification + from .data.data_collator import DataCollatorForWholeWordMask as DataCollatorForWholeWordMask + from .data.data_collator import DataCollatorWithFlattening as DataCollatorWithFlattening + from .data.data_collator import DataCollatorWithPadding as DataCollatorWithPadding + from .data.data_collator import DefaultDataCollator as DefaultDataCollator + from .data.data_collator import default_data_collator as default_data_collator + from .data.datasets import GlueDataset as GlueDataset + from .data.datasets import GlueDataTrainingArguments as GlueDataTrainingArguments + from .data.datasets import LineByLineTextDataset as LineByLineTextDataset + from .data.datasets import LineByLineWithRefDataset as LineByLineWithRefDataset + from .data.datasets import LineByLineWithSOPTextDataset as LineByLineWithSOPTextDataset + from .data.datasets import SquadDataset as SquadDataset + from .data.datasets import SquadDataTrainingArguments as SquadDataTrainingArguments + from .data.datasets import TextDataset as TextDataset + from .data.datasets import TextDatasetForNextSentencePrediction as TextDatasetForNextSentencePrediction + from .feature_extraction_sequence_utils import SequenceFeatureExtractor as SequenceFeatureExtractor + + # Feature Extractor + from .feature_extraction_utils import BatchFeature as BatchFeature + from .feature_extraction_utils import FeatureExtractionMixin as FeatureExtractionMixin + + # Generation + from .generation import AlternatingCodebooksLogitsProcessor as AlternatingCodebooksLogitsProcessor + from .generation import AsyncTextIteratorStreamer as AsyncTextIteratorStreamer + from .generation import BayesianDetectorConfig as BayesianDetectorConfig + from .generation import BayesianDetectorModel as BayesianDetectorModel + from .generation import BeamScorer as BeamScorer + from .generation import ClassifierFreeGuidanceLogitsProcessor as ClassifierFreeGuidanceLogitsProcessor + from .generation import CompileConfig as CompileConfig + from .generation import ConstrainedBeamSearchScorer as ConstrainedBeamSearchScorer + from .generation import Constraint as Constraint + from .generation import ConstraintListState as ConstraintListState + from .generation import DisjunctiveConstraint as DisjunctiveConstraint + from .generation import EncoderNoRepeatNGramLogitsProcessor as EncoderNoRepeatNGramLogitsProcessor + from .generation import EncoderRepetitionPenaltyLogitsProcessor as EncoderRepetitionPenaltyLogitsProcessor + from .generation import EosTokenCriteria as EosTokenCriteria + from .generation import EpsilonLogitsWarper as EpsilonLogitsWarper + from .generation import EtaLogitsWarper as EtaLogitsWarper + from .generation import ExponentialDecayLengthPenalty as ExponentialDecayLengthPenalty + from .generation import FlaxForcedBOSTokenLogitsProcessor as FlaxForcedBOSTokenLogitsProcessor + from .generation import FlaxForcedEOSTokenLogitsProcessor as FlaxForcedEOSTokenLogitsProcessor + from .generation import FlaxForceTokensLogitsProcessor as FlaxForceTokensLogitsProcessor + from .generation import FlaxGenerationMixin as FlaxGenerationMixin + from .generation import FlaxLogitsProcessor as FlaxLogitsProcessor + from .generation import FlaxLogitsProcessorList as FlaxLogitsProcessorList + from .generation import FlaxLogitsWarper as FlaxLogitsWarper + from .generation import FlaxMinLengthLogitsProcessor as FlaxMinLengthLogitsProcessor + from .generation import FlaxSuppressTokensAtBeginLogitsProcessor as FlaxSuppressTokensAtBeginLogitsProcessor + from .generation import FlaxSuppressTokensLogitsProcessor as FlaxSuppressTokensLogitsProcessor + from .generation import FlaxTemperatureLogitsWarper as FlaxTemperatureLogitsWarper + from .generation import FlaxTopKLogitsWarper as FlaxTopKLogitsWarper + from .generation import FlaxTopPLogitsWarper as FlaxTopPLogitsWarper + from .generation import FlaxWhisperTimeStampLogitsProcessor as FlaxWhisperTimeStampLogitsProcessor + from .generation import ForcedBOSTokenLogitsProcessor as ForcedBOSTokenLogitsProcessor + from .generation import ForcedEOSTokenLogitsProcessor as ForcedEOSTokenLogitsProcessor + from .generation import GenerationConfig as GenerationConfig + from .generation import GenerationMixin as GenerationMixin + from .generation import InfNanRemoveLogitsProcessor as InfNanRemoveLogitsProcessor + from .generation import LogitNormalization as LogitNormalization + from .generation import LogitsProcessor as LogitsProcessor + from .generation import LogitsProcessorList as LogitsProcessorList + from .generation import MaxLengthCriteria as MaxLengthCriteria + from .generation import MaxTimeCriteria as MaxTimeCriteria + from .generation import MinLengthLogitsProcessor as MinLengthLogitsProcessor + from .generation import MinNewTokensLengthLogitsProcessor as MinNewTokensLengthLogitsProcessor + from .generation import MinPLogitsWarper as MinPLogitsWarper + from .generation import NoBadWordsLogitsProcessor as NoBadWordsLogitsProcessor + from .generation import NoRepeatNGramLogitsProcessor as NoRepeatNGramLogitsProcessor + from .generation import PhrasalConstraint as PhrasalConstraint + from .generation import PrefixConstrainedLogitsProcessor as PrefixConstrainedLogitsProcessor + from .generation import RepetitionPenaltyLogitsProcessor as RepetitionPenaltyLogitsProcessor + from .generation import SequenceBiasLogitsProcessor as SequenceBiasLogitsProcessor + from .generation import StoppingCriteria as StoppingCriteria + from .generation import StoppingCriteriaList as StoppingCriteriaList + from .generation import StopStringCriteria as StopStringCriteria + from .generation import SuppressTokensAtBeginLogitsProcessor as SuppressTokensAtBeginLogitsProcessor + from .generation import SuppressTokensLogitsProcessor as SuppressTokensLogitsProcessor + from .generation import SynthIDTextWatermarkDetector as SynthIDTextWatermarkDetector + from .generation import SynthIDTextWatermarkingConfig as SynthIDTextWatermarkingConfig + from .generation import SynthIDTextWatermarkLogitsProcessor as SynthIDTextWatermarkLogitsProcessor + from .generation import TemperatureLogitsWarper as TemperatureLogitsWarper + from .generation import TextIteratorStreamer as TextIteratorStreamer + from .generation import TextStreamer as TextStreamer + from .generation import TFForcedBOSTokenLogitsProcessor as TFForcedBOSTokenLogitsProcessor + from .generation import TFForcedEOSTokenLogitsProcessor as TFForcedEOSTokenLogitsProcessor + from .generation import TFForceTokensLogitsProcessor as TFForceTokensLogitsProcessor + from .generation import TFGenerationMixin as TFGenerationMixin + from .generation import TFLogitsProcessor as TFLogitsProcessor + from .generation import TFLogitsProcessorList as TFLogitsProcessorList + from .generation import TFLogitsWarper as TFLogitsWarper + from .generation import TFMinLengthLogitsProcessor as TFMinLengthLogitsProcessor + from .generation import TFNoBadWordsLogitsProcessor as TFNoBadWordsLogitsProcessor + from .generation import TFNoRepeatNGramLogitsProcessor as TFNoRepeatNGramLogitsProcessor + from .generation import TFRepetitionPenaltyLogitsProcessor as TFRepetitionPenaltyLogitsProcessor + from .generation import TFSuppressTokensAtBeginLogitsProcessor as TFSuppressTokensAtBeginLogitsProcessor + from .generation import TFSuppressTokensLogitsProcessor as TFSuppressTokensLogitsProcessor + from .generation import TFTemperatureLogitsWarper as TFTemperatureLogitsWarper + from .generation import TFTopKLogitsWarper as TFTopKLogitsWarper + from .generation import TFTopPLogitsWarper as TFTopPLogitsWarper + from .generation import TopKLogitsWarper as TopKLogitsWarper + from .generation import TopPLogitsWarper as TopPLogitsWarper + from .generation import TypicalLogitsWarper as TypicalLogitsWarper + from .generation import ( + UnbatchedClassifierFreeGuidanceLogitsProcessor as UnbatchedClassifierFreeGuidanceLogitsProcessor, + ) + from .generation import WatermarkDetector as WatermarkDetector + from .generation import WatermarkingConfig as WatermarkingConfig + from .generation import WatermarkLogitsProcessor as WatermarkLogitsProcessor + from .generation import WhisperTimeStampLogitsProcessor as WhisperTimeStampLogitsProcessor + from .hf_argparser import HfArgumentParser as HfArgumentParser + from .image_processing_base import ImageProcessingMixin as ImageProcessingMixin + from .image_processing_utils import BaseImageProcessor as BaseImageProcessor + from .image_processing_utils_fast import BaseImageProcessorFast as BaseImageProcessorFast + from .image_utils import ImageFeatureExtractionMixin as ImageFeatureExtractionMixin + + # Integrations + from .integrations import is_clearml_available as is_clearml_available + from .integrations import is_comet_available as is_comet_available + from .integrations import is_dvclive_available as is_dvclive_available + from .integrations import is_neptune_available as is_neptune_available + from .integrations import is_optuna_available as is_optuna_available + from .integrations import is_ray_available as is_ray_available + from .integrations import is_ray_tune_available as is_ray_tune_available + from .integrations import is_sigopt_available as is_sigopt_available + from .integrations import is_swanlab_available as is_swanlab_available + from .integrations import is_tensorboard_available as is_tensorboard_available + from .integrations import is_trackio_available as is_trackio_available + from .integrations import is_wandb_available as is_wandb_available + from .integrations.executorch import TorchExportableModuleWithStaticCache as TorchExportableModuleWithStaticCache + from .integrations.executorch import convert_and_export_with_cache as convert_and_export_with_cache + from .keras_callbacks import KerasMetricCallback as KerasMetricCallback + from .keras_callbacks import PushToHubCallback as PushToHubCallback + from .masking_utils import AttentionMaskInterface as AttentionMaskInterface + from .model_debugging_utils import model_addition_debugger_context as model_addition_debugger_context + + # Model Cards + from .modelcard import ModelCard as ModelCard + from .modeling_flax_utils import FlaxPreTrainedModel as FlaxPreTrainedModel + from .modeling_layers import GradientCheckpointingLayer as GradientCheckpointingLayer + from .modeling_rope_utils import ROPE_INIT_FUNCTIONS as ROPE_INIT_FUNCTIONS + from .modeling_rope_utils import dynamic_rope_update as dynamic_rope_update + + # TF 2.0 <=> PyTorch conversion utilities + from .modeling_tf_pytorch_utils import ( + convert_tf_weight_name_to_pt_weight_name as convert_tf_weight_name_to_pt_weight_name, + ) + from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model as load_pytorch_checkpoint_in_tf2_model + from .modeling_tf_pytorch_utils import load_pytorch_model_in_tf2_model as load_pytorch_model_in_tf2_model + from .modeling_tf_pytorch_utils import load_pytorch_weights_in_tf2_model as load_pytorch_weights_in_tf2_model + from .modeling_tf_pytorch_utils import load_tf2_checkpoint_in_pytorch_model as load_tf2_checkpoint_in_pytorch_model + from .modeling_tf_pytorch_utils import load_tf2_model_in_pytorch_model as load_tf2_model_in_pytorch_model + from .modeling_tf_pytorch_utils import load_tf2_weights_in_pytorch_model as load_tf2_weights_in_pytorch_model + from .modeling_tf_utils import TFPreTrainedModel as TFPreTrainedModel + from .modeling_tf_utils import TFSequenceSummary as TFSequenceSummary + from .modeling_tf_utils import TFSharedEmbeddings as TFSharedEmbeddings + from .modeling_tf_utils import shape_list as shape_list + from .modeling_utils import AttentionInterface as AttentionInterface + from .modeling_utils import PreTrainedModel as PreTrainedModel + from .models import * + from .models.mamba.modeling_mamba import MambaCache as MambaCache + from .models.timm_wrapper import TimmWrapperImageProcessor as TimmWrapperImageProcessor + + # Optimization + from .optimization import Adafactor as Adafactor + from .optimization import get_constant_schedule as get_constant_schedule + from .optimization import get_constant_schedule_with_warmup as get_constant_schedule_with_warmup + from .optimization import get_cosine_schedule_with_warmup as get_cosine_schedule_with_warmup + from .optimization import ( + get_cosine_with_hard_restarts_schedule_with_warmup as get_cosine_with_hard_restarts_schedule_with_warmup, + ) + from .optimization import ( + get_cosine_with_min_lr_schedule_with_warmup as get_cosine_with_min_lr_schedule_with_warmup, + ) + from .optimization import ( + get_cosine_with_min_lr_schedule_with_warmup_lr_rate as get_cosine_with_min_lr_schedule_with_warmup_lr_rate, + ) + from .optimization import get_inverse_sqrt_schedule as get_inverse_sqrt_schedule + from .optimization import get_linear_schedule_with_warmup as get_linear_schedule_with_warmup + from .optimization import get_polynomial_decay_schedule_with_warmup as get_polynomial_decay_schedule_with_warmup + from .optimization import get_scheduler as get_scheduler + from .optimization import get_wsd_schedule as get_wsd_schedule + + # Optimization + from .optimization_tf import AdamWeightDecay as AdamWeightDecay + from .optimization_tf import GradientAccumulator as GradientAccumulator + from .optimization_tf import WarmUp as WarmUp + from .optimization_tf import create_optimizer as create_optimizer + + # Pipelines + from .pipelines import AudioClassificationPipeline as AudioClassificationPipeline + from .pipelines import AutomaticSpeechRecognitionPipeline as AutomaticSpeechRecognitionPipeline + from .pipelines import CsvPipelineDataFormat as CsvPipelineDataFormat + from .pipelines import DepthEstimationPipeline as DepthEstimationPipeline + from .pipelines import DocumentQuestionAnsweringPipeline as DocumentQuestionAnsweringPipeline + from .pipelines import FeatureExtractionPipeline as FeatureExtractionPipeline + from .pipelines import FillMaskPipeline as FillMaskPipeline + from .pipelines import ImageClassificationPipeline as ImageClassificationPipeline + from .pipelines import ImageFeatureExtractionPipeline as ImageFeatureExtractionPipeline + from .pipelines import ImageSegmentationPipeline as ImageSegmentationPipeline + from .pipelines import ImageTextToTextPipeline as ImageTextToTextPipeline + from .pipelines import ImageToImagePipeline as ImageToImagePipeline + from .pipelines import ImageToTextPipeline as ImageToTextPipeline + from .pipelines import JsonPipelineDataFormat as JsonPipelineDataFormat + from .pipelines import KeypointMatchingPipeline as KeypointMatchingPipeline + from .pipelines import MaskGenerationPipeline as MaskGenerationPipeline + from .pipelines import NerPipeline as NerPipeline + from .pipelines import ObjectDetectionPipeline as ObjectDetectionPipeline + from .pipelines import PipedPipelineDataFormat as PipedPipelineDataFormat + from .pipelines import Pipeline as Pipeline + from .pipelines import PipelineDataFormat as PipelineDataFormat + from .pipelines import QuestionAnsweringPipeline as QuestionAnsweringPipeline + from .pipelines import SummarizationPipeline as SummarizationPipeline + from .pipelines import TableQuestionAnsweringPipeline as TableQuestionAnsweringPipeline + from .pipelines import Text2TextGenerationPipeline as Text2TextGenerationPipeline + from .pipelines import TextClassificationPipeline as TextClassificationPipeline + from .pipelines import TextGenerationPipeline as TextGenerationPipeline + from .pipelines import TextToAudioPipeline as TextToAudioPipeline + from .pipelines import TokenClassificationPipeline as TokenClassificationPipeline + from .pipelines import TranslationPipeline as TranslationPipeline + from .pipelines import VideoClassificationPipeline as VideoClassificationPipeline + from .pipelines import VisualQuestionAnsweringPipeline as VisualQuestionAnsweringPipeline + from .pipelines import ZeroShotAudioClassificationPipeline as ZeroShotAudioClassificationPipeline + from .pipelines import ZeroShotClassificationPipeline as ZeroShotClassificationPipeline + from .pipelines import ZeroShotImageClassificationPipeline as ZeroShotImageClassificationPipeline + from .pipelines import ZeroShotObjectDetectionPipeline as ZeroShotObjectDetectionPipeline + from .pipelines import pipeline as pipeline + from .processing_utils import ProcessorMixin as ProcessorMixin + from .pytorch_utils import Conv1D as Conv1D + from .pytorch_utils import apply_chunking_to_forward as apply_chunking_to_forward + from .pytorch_utils import prune_layer as prune_layer + + # Tokenization + from .tokenization_utils import PreTrainedTokenizer as PreTrainedTokenizer + from .tokenization_utils_base import AddedToken as AddedToken + from .tokenization_utils_base import BatchEncoding as BatchEncoding + from .tokenization_utils_base import CharSpan as CharSpan + from .tokenization_utils_base import PreTrainedTokenizerBase as PreTrainedTokenizerBase + from .tokenization_utils_base import SpecialTokensMixin as SpecialTokensMixin + from .tokenization_utils_base import TokenSpan as TokenSpan + from .tokenization_utils_fast import PreTrainedTokenizerFast as PreTrainedTokenizerFast + + # Trainer + from .trainer import Trainer as Trainer + + # Trainer + from .trainer_callback import DefaultFlowCallback as DefaultFlowCallback + from .trainer_callback import EarlyStoppingCallback as EarlyStoppingCallback + from .trainer_callback import PrinterCallback as PrinterCallback + from .trainer_callback import ProgressCallback as ProgressCallback + from .trainer_callback import TrainerCallback as TrainerCallback + from .trainer_callback import TrainerControl as TrainerControl + from .trainer_callback import TrainerState as TrainerState + from .trainer_pt_utils import torch_distributed_zero_first as torch_distributed_zero_first + from .trainer_seq2seq import Seq2SeqTrainer as Seq2SeqTrainer + from .trainer_utils import EvalPrediction as EvalPrediction + from .trainer_utils import IntervalStrategy as IntervalStrategy + from .trainer_utils import SchedulerType as SchedulerType + from .trainer_utils import enable_full_determinism as enable_full_determinism + from .trainer_utils import set_seed as set_seed + from .training_args import TrainingArguments as TrainingArguments + from .training_args_seq2seq import Seq2SeqTrainingArguments as Seq2SeqTrainingArguments + from .training_args_tf import TFTrainingArguments as TFTrainingArguments + + # Files and general utilities + from .utils import CONFIG_NAME as CONFIG_NAME + from .utils import MODEL_CARD_NAME as MODEL_CARD_NAME + from .utils import PYTORCH_PRETRAINED_BERT_CACHE as PYTORCH_PRETRAINED_BERT_CACHE + from .utils import PYTORCH_TRANSFORMERS_CACHE as PYTORCH_TRANSFORMERS_CACHE + from .utils import SPIECE_UNDERLINE as SPIECE_UNDERLINE + from .utils import TF2_WEIGHTS_NAME as TF2_WEIGHTS_NAME + from .utils import TF_WEIGHTS_NAME as TF_WEIGHTS_NAME + from .utils import TRANSFORMERS_CACHE as TRANSFORMERS_CACHE + from .utils import WEIGHTS_NAME as WEIGHTS_NAME + from .utils import TensorType as TensorType + from .utils import add_end_docstrings as add_end_docstrings + from .utils import add_start_docstrings as add_start_docstrings + from .utils import is_apex_available as is_apex_available + from .utils import is_av_available as is_av_available + from .utils import is_datasets_available as is_datasets_available + from .utils import is_faiss_available as is_faiss_available + from .utils import is_matplotlib_available as is_matplotlib_available + from .utils import is_phonemizer_available as is_phonemizer_available + from .utils import is_psutil_available as is_psutil_available + from .utils import is_py3nvml_available as is_py3nvml_available + from .utils import is_pyctcdecode_available as is_pyctcdecode_available + from .utils import is_sacremoses_available as is_sacremoses_available + from .utils import is_safetensors_available as is_safetensors_available + from .utils import is_sklearn_available as is_sklearn_available + from .utils import is_torch_hpu_available as is_torch_hpu_available + from .utils import is_torch_mlu_available as is_torch_mlu_available + from .utils import is_torch_musa_available as is_torch_musa_available + from .utils import is_torch_neuroncore_available as is_torch_neuroncore_available + from .utils import is_torch_npu_available as is_torch_npu_available + from .utils import is_torch_xla_available as is_torch_xla_available + from .utils import is_torch_xpu_available as is_torch_xpu_available + + # bitsandbytes config + from .utils.quantization_config import AqlmConfig as AqlmConfig + from .utils.quantization_config import AutoRoundConfig as AutoRoundConfig + from .utils.quantization_config import AwqConfig as AwqConfig + from .utils.quantization_config import BitNetQuantConfig as BitNetQuantConfig + from .utils.quantization_config import BitsAndBytesConfig as BitsAndBytesConfig + from .utils.quantization_config import CompressedTensorsConfig as CompressedTensorsConfig + from .utils.quantization_config import EetqConfig as EetqConfig + from .utils.quantization_config import FbgemmFp8Config as FbgemmFp8Config + from .utils.quantization_config import FineGrainedFP8Config as FineGrainedFP8Config + from .utils.quantization_config import FPQuantConfig as FPQuantConfig + from .utils.quantization_config import GPTQConfig as GPTQConfig + from .utils.quantization_config import HiggsConfig as HiggsConfig + from .utils.quantization_config import HqqConfig as HqqConfig + from .utils.quantization_config import QuantoConfig as QuantoConfig + from .utils.quantization_config import QuarkConfig as QuarkConfig + from .utils.quantization_config import SpQRConfig as SpQRConfig + from .utils.quantization_config import TorchAoConfig as TorchAoConfig + from .utils.quantization_config import VptqConfig as VptqConfig + from .video_processing_utils import BaseVideoProcessor as BaseVideoProcessor + +else: + import sys + + _import_structure = {k: set(v) for k, v in _import_structure.items()} + + import_structure = define_import_structure(Path(__file__).parent / "models", prefix="models") + import_structure[frozenset({})].update(_import_structure) + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + import_structure, + module_spec=__spec__, + extra_objects={"__version__": __version__}, + ) + + +if not is_tf_available() and not is_torch_available() and not is_flax_available(): + logger.warning_advice( + "None of PyTorch, TensorFlow >= 2.0, or Flax have been found. " + "Models won't be available and only tokenizers, configuration " + "and file/data utilities can be used." + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/activations.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/activations.py new file mode 100644 index 0000000000000000000000000000000000000000..7642e8aa238a6df701da95f0a58bef3156baf0e0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/activations.py @@ -0,0 +1,356 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import math +from collections import OrderedDict + +import torch +from torch import Tensor, nn + +from .integrations.hub_kernels import use_kernel_forward_from_hub +from .utils import logging +from .utils.import_utils import is_torchdynamo_compiling + + +logger = logging.get_logger(__name__) + + +@use_kernel_forward_from_hub("GeluTanh") +class GELUTanh(nn.Module): + """ + A fast C implementation of the tanh approximation of the GeLU activation function. See + https://huggingface.co/papers/1606.08415. + + This implementation is equivalent to NewGELU and FastGELU but much faster. However, it is not an exact numerical + match due to rounding errors. + """ + + def __init__(self, use_gelu_tanh_python: bool = False): + super().__init__() + if use_gelu_tanh_python: + self.act = self._gelu_tanh_python + else: + self.act = functools.partial(nn.functional.gelu, approximate="tanh") + + def _gelu_tanh_python(self, input: Tensor) -> Tensor: + return input * 0.5 * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0)))) + + def forward(self, input: Tensor) -> Tensor: + return self.act(input) + + +@use_kernel_forward_from_hub("NewGELU") +class NewGELUActivation(nn.Module): + """ + Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see + the Gaussian Error Linear Units paper: https://huggingface.co/papers/1606.08415 + """ + + def forward(self, input: Tensor) -> Tensor: + return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0)))) + + +@use_kernel_forward_from_hub("GeLU") +class GELUActivation(nn.Module): + """ + Original Implementation of the GELU activation function in Google BERT repo when initially created. For + information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 + + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional + Also see the Gaussian Error Linear Units paper: https://huggingface.co/papers/1606.08415 + """ + + def __init__(self, use_gelu_python: bool = False): + super().__init__() + if use_gelu_python: + self.act = self._gelu_python + else: + self.act = nn.functional.gelu + + def _gelu_python(self, input: Tensor) -> Tensor: + return input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0))) + + def forward(self, input: Tensor) -> Tensor: + return self.act(input) + + +@use_kernel_forward_from_hub("SiLU") +class SiLUActivation(nn.Module): + """ + See Gaussian Error Linear Units (Hendrycks et al., https://arxiv.org/abs/1606.08415) where the SiLU (Sigmoid Linear + Unit) was originally introduced and coined, and see Sigmoid-Weighted Linear Units for Neural Network Function + Approximation in Reinforcement Learning (Elfwing et al., https://arxiv.org/abs/1702.03118) and Swish: a Self-Gated + Activation Function (Ramachandran et al., https://arxiv.org/abs/1710.05941v1) where the SiLU was experimented with + later. + """ + + def forward(self, input: Tensor) -> Tensor: + return nn.functional.silu(input) + + +@use_kernel_forward_from_hub("FastGELU") +class FastGELUActivation(nn.Module): + """ + Applies GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs + """ + + def forward(self, input: Tensor) -> Tensor: + return 0.5 * input * (1.0 + torch.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input))) + + +@use_kernel_forward_from_hub("QuickGELU") +class QuickGELUActivation(nn.Module): + """ + Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs + """ + + def forward(self, input: Tensor) -> Tensor: + return input * torch.sigmoid(1.702 * input) + + +class ClippedGELUActivation(nn.Module): + """ + Clip the range of possible GeLU outputs between [min, max]. This is especially useful for quantization purpose, as + it allows mapping negatives values in the GeLU spectrum. For more information on this trick, please refer to + https://huggingface.co/papers/2004.09602. + + Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when + initially created. + + For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 + + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))). See https://huggingface.co/papers/1606.08415 + """ + + def __init__(self, min: float, max: float): + if min > max: + raise ValueError(f"min should be < max (got min: {min}, max: {max})") + + super().__init__() + self.min = min + self.max = max + + def forward(self, x: Tensor) -> Tensor: + return torch.clip(gelu(x), self.min, self.max) + + +class AccurateGELUActivation(nn.Module): + """ + Applies GELU approximation that is faster than default and more accurate than QuickGELU. See: + https://github.com/hendrycks/GELUs + + Implemented along with MEGA (Moving Average Equipped Gated Attention) + """ + + def __init__(self): + super().__init__() + self.precomputed_constant = math.sqrt(2 / math.pi) + + def forward(self, input: Tensor) -> Tensor: + return 0.5 * input * (1 + torch.tanh(self.precomputed_constant * (input + 0.044715 * torch.pow(input, 3)))) + + +class MishActivation(nn.Module): + """ + See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://huggingface.co/papers/1908.08681). Also + visit the official repository for the paper: https://github.com/digantamisra98/Mish + """ + + def __init__(self): + super().__init__() + self.act = nn.functional.mish + + def _mish_python(self, input: Tensor) -> Tensor: + return input * torch.tanh(nn.functional.softplus(input)) + + def forward(self, input: Tensor) -> Tensor: + return self.act(input) + + +class LinearActivation(nn.Module): + """ + Applies the linear activation function, i.e. forwarding input directly to output. + """ + + def forward(self, input: Tensor) -> Tensor: + return input + + +class LaplaceActivation(nn.Module): + """ + Applies elementwise activation based on Laplace function, introduced in MEGA as an attention activation. See + https://huggingface.co/papers/2209.10655 + + Inspired by squared relu, but with bounded range and gradient for better stability + """ + + def forward(self, input, mu=0.707107, sigma=0.282095): + input = (input - mu).div(sigma * math.sqrt(2.0)) + return 0.5 * (1.0 + torch.erf(input)) + + +class ReLUSquaredActivation(nn.Module): + """ + Applies the relu^2 activation introduced in https://huggingface.co/papers/2109.08668v2 + """ + + def forward(self, input): + relu_applied = nn.functional.relu(input) + squared = torch.square(relu_applied) + return squared + + +class ClassInstantier(OrderedDict): + def __getitem__(self, key): + content = super().__getitem__(key) + cls, kwargs = content if isinstance(content, tuple) else (content, {}) + return cls(**kwargs) + + +class XIELUActivation(nn.Module): + """ + Applies the xIELU activation function introduced in https://arxiv.org/abs/2411.13010 + + If the user has installed the nickjbrowning/XIELU wheel, we import xIELU CUDA + Otherwise, we emit a single warning and use xIELU Python + """ + + def __init__( + self, + alpha_p_init=0.8, + alpha_n_init=0.8, + beta=0.5, + eps=-1e-6, + dtype=torch.bfloat16, + with_vector_loads=False, + ): + super().__init__() + self.alpha_p = nn.Parameter(torch.log(torch.expm1(torch.tensor(alpha_p_init, dtype=dtype))).unsqueeze(0)) + self.alpha_n = nn.Parameter( + torch.log(torch.expm1(torch.tensor(alpha_n_init - beta, dtype=dtype))).unsqueeze(0) + ) + self.register_buffer("beta", torch.tensor(beta, dtype=dtype)) + self.register_buffer("eps", torch.tensor(eps, dtype=dtype)) + self.with_vector_loads = with_vector_loads + # Temporary until xIELU CUDA fully implemented + self._beta_scalar = float(self.beta.detach().cpu().float().item()) + self._eps_scalar = float(self.eps.detach().cpu().float().item()) + + self._xielu_cuda_obj = None + try: + import xielu.ops # noqa: F401 + + self._xielu_cuda_obj = torch.classes.xielu.XIELU() + msg = "Using experimental xIELU CUDA." + try: + from torch._dynamo import allow_in_graph + + self._xielu_cuda_fn = allow_in_graph(self._xielu_cuda) + msg += " Enabled torch._dynamo for xIELU CUDA." + except Exception as err: + msg += f" Could not enable torch._dynamo for xIELU ({err}) - this may result in slower performance." + self._xielu_cuda_fn = self._xielu_cuda + logger.warning_once(msg) + except Exception as err: + logger.warning_once( + "CUDA-fused xIELU not available (%s) – falling back to a Python version.\n" + "For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`", + str(err), + ) + + def _xielu_python(self, x: Tensor) -> Tensor: + alpha_p = nn.functional.softplus(self.alpha_p) + alpha_n = self.beta + nn.functional.softplus(self.alpha_n) + return torch.where( + x > 0, + alpha_p * x * x + self.beta * x, + (torch.expm1(torch.min(x, self.eps)) - x) * alpha_n + self.beta * x, + ) + + def _xielu_cuda(self, x: Tensor) -> Tensor: + """Firewall function to prevent torch.compile from seeing .item() calls""" + original_shape = x.shape + # CUDA kernel expects 3D tensors, reshape if needed + while x.dim() < 3: + x = x.unsqueeze(0) + if x.dim() > 3: + x = x.view(-1, 1, x.size(-1)) + if original_shape != x.shape: + logger.warning_once( + "Warning: xIELU input tensor expects 3 dimensions but got (shape: %s). Reshaping to (shape: %s).", + original_shape, + x.shape, + ) + result = self._xielu_cuda_obj.forward( + x, + self.alpha_p.to(x.dtype), + self.alpha_n.to(x.dtype), + # Temporary until xIELU CUDA fully implemented -> self.{beta,eps}.item() + self._beta_scalar, + self._eps_scalar, + self.with_vector_loads, + ) + return result.view(original_shape) + + def forward(self, input: Tensor) -> Tensor: + if self._xielu_cuda_obj is not None and input.is_cuda: + if not is_torchdynamo_compiling(): + return self._xielu_cuda_fn(input) + else: + logger.warning_once("torch._dynamo is compiling, using Python version of xIELU.") + return self._xielu_python(input) + + +ACT2CLS = { + "gelu": GELUActivation, + "gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}), + "gelu_fast": FastGELUActivation, + "gelu_new": NewGELUActivation, + "gelu_python": (GELUActivation, {"use_gelu_python": True}), + "gelu_pytorch_tanh": GELUTanh, + "gelu_python_tanh": (GELUTanh, {"use_gelu_tanh_python": True}), + "gelu_accurate": AccurateGELUActivation, + "laplace": LaplaceActivation, + "leaky_relu": nn.LeakyReLU, + "linear": LinearActivation, + "mish": MishActivation, + "quick_gelu": QuickGELUActivation, + "relu": nn.ReLU, + "relu2": ReLUSquaredActivation, + "relu6": nn.ReLU6, + "sigmoid": nn.Sigmoid, + "silu": SiLUActivation, + "swish": nn.SiLU, + "tanh": nn.Tanh, + "prelu": nn.PReLU, + "xielu": XIELUActivation, +} +ACT2FN = ClassInstantier(ACT2CLS) + + +def get_activation(activation_string): + if activation_string in ACT2FN: + return ACT2FN[activation_string] + else: + raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}") + + +# For backwards compatibility with: from activations import gelu_python +gelu_python = get_activation("gelu_python") +gelu_new = get_activation("gelu_new") +gelu = get_activation("gelu") +gelu_fast = get_activation("gelu_fast") +quick_gelu = get_activation("quick_gelu") +silu = get_activation("silu") +mish = get_activation("mish") +linear_act = get_activation("linear") diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/activations_tf.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/activations_tf.py new file mode 100644 index 0000000000000000000000000000000000000000..8dccf6c4f46b8fe1f98d7e57bd8611f660ed19f4 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/activations_tf.py @@ -0,0 +1,147 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import tensorflow as tf +from packaging.version import parse + + +try: + import tf_keras as keras +except (ModuleNotFoundError, ImportError): + import keras + + if parse(keras.__version__).major > 2: + raise ValueError( + "Your currently installed version of Keras is Keras 3, but this is not yet supported in " + "Transformers. Please install the backwards-compatible tf-keras package with " + "`pip install tf-keras`." + ) + + +def _gelu(x): + """ + Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when + initially created. For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): + 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) Also see + https://huggingface.co/papers/1606.08415 + """ + x = tf.convert_to_tensor(x) + cdf = 0.5 * (1.0 + tf.math.erf(x / tf.cast(tf.sqrt(2.0), x.dtype))) + + return x * cdf + + +def _gelu_new(x): + """ + Gaussian Error Linear Unit. This is a smoother version of the GELU. Original paper: https://huggingface.co/papers/1606.0841 + + Args: + x: float Tensor to perform activation + + Returns: + `x` with the GELU activation applied. + """ + x = tf.convert_to_tensor(x) + pi = tf.cast(math.pi, x.dtype) + coeff = tf.cast(0.044715, x.dtype) + cdf = 0.5 * (1.0 + tf.tanh(tf.sqrt(2.0 / pi) * (x + coeff * tf.pow(x, 3)))) + + return x * cdf + + +def mish(x): + x = tf.convert_to_tensor(x) + + return x * tf.tanh(tf.math.softplus(x)) + + +def gelu_fast(x): + x = tf.convert_to_tensor(x) + coeff1 = tf.cast(0.044715, x.dtype) + coeff2 = tf.cast(0.7978845608, x.dtype) + + return 0.5 * x * (1.0 + tf.tanh(x * coeff2 * (1.0 + coeff1 * x * x))) + + +def quick_gelu(x): + x = tf.convert_to_tensor(x) + coeff = tf.cast(1.702, x.dtype) + return x * tf.math.sigmoid(coeff * x) + + +def gelu_10(x): + """ + Clip the range of possible GeLU outputs between [-10, 10]. This is especially useful for quantization purpose, as + it allows mapping 2 negatives values in the GeLU spectrum. For more information on this trick, please refer to + https://huggingface.co/papers/2004.09602 + + Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when + initially created. For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): + 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) Also see + https://huggingface.co/papers/1606.08415 :param x: :return: + """ + return tf.clip_by_value(_gelu(x), -10, 10) + + +def glu(x, axis=-1): + """ + Gated Linear Unit. Implementation as defined in the original paper (see https://huggingface.co/papers/1612.08083), where + the input `x` is split in two halves across a dimension (`axis`), A and B, returning A * sigmoid(B). + + Args: + `x`: float Tensor to perform activation + `axis`: dimension across which `x` be split in half + + Returns: + `x` with the GLU activation applied (with its size halved across the dimension `axis`). + """ + a, b = tf.split(x, 2, axis=axis) + return a * tf.math.sigmoid(b) + + +if parse(tf.version.VERSION) >= parse("2.4"): + + def approximate_gelu_wrap(x): + return keras.activations.gelu(x, approximate=True) + + gelu = keras.activations.gelu + gelu_new = approximate_gelu_wrap +else: + gelu = _gelu + gelu_new = _gelu_new + + +ACT2FN = { + "gelu": gelu, + "gelu_10": gelu_10, + "gelu_fast": gelu_fast, + "gelu_new": gelu_new, + "glu": glu, + "mish": mish, + "quick_gelu": quick_gelu, + "relu": keras.activations.relu, + "sigmoid": keras.activations.sigmoid, + "silu": keras.activations.swish, + "swish": keras.activations.swish, + "tanh": keras.activations.tanh, +} + + +def get_tf_activation(activation_string): + if activation_string in ACT2FN: + return ACT2FN[activation_string] + else: + raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}") diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/audio_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/audio_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5de56618014ecdb1443df0968018ce25dc394fc0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/audio_utils.py @@ -0,0 +1,1224 @@ +# Copyright 2023 The HuggingFace Inc. team and the librosa & torchaudio authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Audio processing functions to extract features from audio waveforms. This code is pure numpy to support all frameworks +and remove unnecessary dependencies. +""" + +import base64 +import importlib +import io +import os +import warnings +from collections.abc import Sequence +from io import BytesIO +from typing import TYPE_CHECKING, Any, Optional, Union + + +if TYPE_CHECKING: + import torch +import numpy as np +import requests +from packaging import version + +from .utils import ( + is_librosa_available, + is_numpy_array, + is_soundfile_available, + is_torch_tensor, + is_torchcodec_available, + requires_backends, +) + + +if is_soundfile_available(): + import soundfile as sf + +if is_librosa_available(): + import librosa + + # TODO: @eustlb, we actually don't need librosa but soxr is installed with librosa + import soxr + +if is_torchcodec_available(): + TORCHCODEC_VERSION = version.parse(importlib.metadata.version("torchcodec")) + +AudioInput = Union[np.ndarray, "torch.Tensor", Sequence[np.ndarray], Sequence["torch.Tensor"]] + + +def load_audio(audio: Union[str, np.ndarray], sampling_rate=16000, timeout=None) -> np.ndarray: + """ + Loads `audio` to an np.ndarray object. + + Args: + audio (`str` or `np.ndarray`): + The audio to be loaded to the numpy array format. + sampling_rate (`int`, *optional*, defaults to 16000): + The sampling rate to be used when loading the audio. It should be same as the + sampling rate the model you will be using further was trained with. + timeout (`float`, *optional*): + The timeout value in seconds for the URL request. + + Returns: + `np.ndarray`: A numpy array representing the audio. + """ + if isinstance(audio, str): + # Try to load with `torchcodec` but do not enforce users to install it. If not found + # fallback to `librosa`. If using an audio-only model, most probably `torchcodec` won't be + # needed. Do not raise any errors if not installed or versions do not match + if is_torchcodec_available() and TORCHCODEC_VERSION >= version.parse("0.3.0"): + audio = load_audio_torchcodec(audio, sampling_rate=sampling_rate) + else: + audio = load_audio_librosa(audio, sampling_rate=sampling_rate, timeout=timeout) + elif not isinstance(audio, np.ndarray): + raise TypeError( + "Incorrect format used for `audio`. Should be an url linking to an audio, a local path, or numpy array." + ) + return audio + + +def load_audio_torchcodec(audio: Union[str, np.ndarray], sampling_rate=16000) -> np.ndarray: + """ + Loads `audio` to an np.ndarray object using `torchcodec`. + + Args: + audio (`str` or `np.ndarray`): + The audio to be loaded to the numpy array format. + sampling_rate (`int`, *optional*, defaults to 16000): + The sampling rate to be used when loading the audio. It should be same as the + sampling rate the model you will be using further was trained with. + + Returns: + `np.ndarray`: A numpy array representing the audio. + """ + # Lazy import so that issues in torchcodec compatibility don't crash the whole library + requires_backends(load_audio_torchcodec, ["torchcodec"]) + from torchcodec.decoders import AudioDecoder + + # Set `num_channels` to `1` which is what most models expects and the default in librosa + decoder = AudioDecoder(audio, sample_rate=sampling_rate, num_channels=1) + audio = decoder.get_all_samples().data[0].numpy() # NOTE: feature extractors don't accept torch tensors + return audio + + +def load_audio_librosa(audio: Union[str, np.ndarray], sampling_rate=16000, timeout=None) -> np.ndarray: + """ + Loads `audio` to an np.ndarray object using `librosa`. + + Args: + audio (`str` or `np.ndarray`): + The audio to be loaded to the numpy array format. + sampling_rate (`int`, *optional*, defaults to 16000): + The sampling rate to be used when loading the audio. It should be same as the + sampling rate the model you will be using further was trained with. + timeout (`float`, *optional*): + The timeout value in seconds for the URL request. + + Returns: + `np.ndarray`: A numpy array representing the audio. + """ + requires_backends(load_audio_librosa, ["librosa"]) + + # Load audio from URL (e.g https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/translate_to_chinese.wav) + if audio.startswith("http://") or audio.startswith("https://"): + audio = librosa.load(BytesIO(requests.get(audio, timeout=timeout).content), sr=sampling_rate)[0] + elif os.path.isfile(audio): + audio = librosa.load(audio, sr=sampling_rate)[0] + return audio + + +def load_audio_as( + audio: str, + return_format: str, + timeout: Optional[int] = None, + force_mono: bool = False, + sampling_rate: Optional[int] = None, +) -> Union[str, dict[str, Any], io.BytesIO, None]: + """ + Load audio from either a local file path or URL and return in specified format. + + Args: + audio (`str`): Either a local file path or a URL to an audio file + return_format (`str`): Format to return the audio in: + - "base64": Base64 encoded string + - "dict": Dictionary with data and format + - "buffer": BytesIO object + timeout (`int`, *optional*): Timeout for URL requests in seconds + force_mono (`bool`): Whether to convert stereo audio to mono + sampling_rate (`int`, *optional*): If provided, the audio will be resampled to the specified sampling rate. + + Returns: + `Union[str, Dict[str, Any], io.BytesIO, None]`: + - `str`: Base64 encoded audio data (if return_format="base64") + - `dict`: Dictionary with 'data' (base64 encoded audio data) and 'format' keys (if return_format="dict") + - `io.BytesIO`: BytesIO object containing audio data (if return_format="buffer") + """ + # TODO: @eustlb, we actually don't need librosa but soxr is installed with librosa + requires_backends(load_audio_as, ["librosa"]) + + if return_format not in ["base64", "dict", "buffer"]: + raise ValueError(f"Invalid return_format: {return_format}. Must be 'base64', 'dict', or 'buffer'") + + try: + # Load audio bytes from URL or file + audio_bytes = None + if audio.startswith(("http://", "https://")): + response = requests.get(audio, timeout=timeout) + response.raise_for_status() + audio_bytes = response.content + elif os.path.isfile(audio): + with open(audio, "rb") as audio_file: + audio_bytes = audio_file.read() + else: + raise ValueError(f"File not found: {audio}") + + # Process audio data + with io.BytesIO(audio_bytes) as audio_file: + with sf.SoundFile(audio_file) as f: + audio_array = f.read(dtype="float32") + original_sr = f.samplerate + audio_format = f.format + if sampling_rate is not None and sampling_rate != original_sr: + # Resample audio to target sampling rate + audio_array = soxr.resample(audio_array, original_sr, sampling_rate, quality="HQ") + else: + sampling_rate = original_sr + + # Convert to mono if needed + if force_mono and audio_array.ndim != 1: + audio_array = audio_array.mean(axis=1) + + buffer = io.BytesIO() + sf.write(buffer, audio_array, sampling_rate, format=audio_format.upper()) + buffer.seek(0) + + if return_format == "buffer": + return buffer + elif return_format == "base64": + return base64.b64encode(buffer.read()).decode("utf-8") + elif return_format == "dict": + return { + "data": base64.b64encode(buffer.read()).decode("utf-8"), + "format": audio_format.lower(), + } + + except Exception as e: + raise ValueError(f"Error loading audio: {e}") + + +def is_valid_audio(audio): + return is_numpy_array(audio) or is_torch_tensor(audio) + + +def is_valid_list_of_audio(audio): + return audio and all(is_valid_audio(audio_i) for audio_i in audio) + + +def make_list_of_audio( + audio: Union[list[AudioInput], AudioInput], +) -> AudioInput: + """ + Ensure that the output is a list of audio. + Args: + audio (`Union[list[AudioInput], AudioInput]`): + The input audio. + Returns: + list: A list of audio. + """ + # If it's a list of audios, it's already in the right format + if isinstance(audio, (list, tuple)) and is_valid_list_of_audio(audio): + return audio + + # If it's a single audio, convert it to a list of + if is_valid_audio(audio): + return [audio] + + raise ValueError("Invalid input type. Must be a single audio or a list of audio") + + +def hertz_to_mel(freq: Union[float, np.ndarray], mel_scale: str = "htk") -> Union[float, np.ndarray]: + """ + Convert frequency from hertz to mels. + + Args: + freq (`float` or `np.ndarray`): + The frequency, or multiple frequencies, in hertz (Hz). + mel_scale (`str`, *optional*, defaults to `"htk"`): + The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`. + + Returns: + `float` or `np.ndarray`: The frequencies on the mel scale. + """ + + if mel_scale not in ["slaney", "htk", "kaldi"]: + raise ValueError('mel_scale should be one of "htk", "slaney" or "kaldi".') + + if mel_scale == "htk": + return 2595.0 * np.log10(1.0 + (freq / 700.0)) + elif mel_scale == "kaldi": + return 1127.0 * np.log(1.0 + (freq / 700.0)) + + min_log_hertz = 1000.0 + min_log_mel = 15.0 + logstep = 27.0 / np.log(6.4) + mels = 3.0 * freq / 200.0 + + if isinstance(freq, np.ndarray): + log_region = freq >= min_log_hertz + mels[log_region] = min_log_mel + np.log(freq[log_region] / min_log_hertz) * logstep + elif freq >= min_log_hertz: + mels = min_log_mel + np.log(freq / min_log_hertz) * logstep + + return mels + + +def mel_to_hertz(mels: Union[float, np.ndarray], mel_scale: str = "htk") -> Union[float, np.ndarray]: + """ + Convert frequency from mels to hertz. + + Args: + mels (`float` or `np.ndarray`): + The frequency, or multiple frequencies, in mels. + mel_scale (`str`, *optional*, `"htk"`): + The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`. + + Returns: + `float` or `np.ndarray`: The frequencies in hertz. + """ + + if mel_scale not in ["slaney", "htk", "kaldi"]: + raise ValueError('mel_scale should be one of "htk", "slaney" or "kaldi".') + + if mel_scale == "htk": + return 700.0 * (np.power(10, mels / 2595.0) - 1.0) + elif mel_scale == "kaldi": + return 700.0 * (np.exp(mels / 1127.0) - 1.0) + + min_log_hertz = 1000.0 + min_log_mel = 15.0 + logstep = np.log(6.4) / 27.0 + freq = 200.0 * mels / 3.0 + + if isinstance(mels, np.ndarray): + log_region = mels >= min_log_mel + freq[log_region] = min_log_hertz * np.exp(logstep * (mels[log_region] - min_log_mel)) + elif mels >= min_log_mel: + freq = min_log_hertz * np.exp(logstep * (mels - min_log_mel)) + + return freq + + +def hertz_to_octave(freq: Union[float, np.ndarray], tuning: float = 0.0, bins_per_octave: int = 12): + """ + Convert frequency from hertz to fractional octave numbers. + Adapted from *librosa*. + + Args: + freq (`float` or `np.ndarray`): + The frequency, or multiple frequencies, in hertz (Hz). + tuning (`float`, defaults to `0.`): + Tuning deviation from the Stuttgart pitch (A440) in (fractional) bins per octave. + bins_per_octave (`int`, defaults to `12`): + Number of bins per octave. + + Returns: + `float` or `np.ndarray`: The frequencies on the octave scale. + """ + stuttgart_pitch = 440.0 * 2.0 ** (tuning / bins_per_octave) + octave = np.log2(freq / (float(stuttgart_pitch) / 16)) + return octave + + +def _create_triangular_filter_bank(fft_freqs: np.ndarray, filter_freqs: np.ndarray) -> np.ndarray: + """ + Creates a triangular filter bank. + + Adapted from *torchaudio* and *librosa*. + + Args: + fft_freqs (`np.ndarray` of shape `(num_frequency_bins,)`): + Discrete frequencies of the FFT bins in Hz. + filter_freqs (`np.ndarray` of shape `(num_mel_filters,)`): + Center frequencies of the triangular filters to create, in Hz. + + Returns: + `np.ndarray` of shape `(num_frequency_bins, num_mel_filters)` + """ + filter_diff = np.diff(filter_freqs) + slopes = np.expand_dims(filter_freqs, 0) - np.expand_dims(fft_freqs, 1) + down_slopes = -slopes[:, :-2] / filter_diff[:-1] + up_slopes = slopes[:, 2:] / filter_diff[1:] + return np.maximum(np.zeros(1), np.minimum(down_slopes, up_slopes)) + + +def chroma_filter_bank( + num_frequency_bins: int, + num_chroma: int, + sampling_rate: int, + tuning: float = 0.0, + power: Optional[float] = 2.0, + weighting_parameters: Optional[tuple[float, float]] = (5.0, 2.0), + start_at_c_chroma: bool = True, +): + """ + Creates a chroma filter bank, i.e a linear transformation to project spectrogram bins onto chroma bins. + + Adapted from *librosa*. + + Args: + num_frequency_bins (`int`): + Number of frequencies used to compute the spectrogram (should be the same as in `stft`). + num_chroma (`int`): + Number of chroma bins (i.e pitch classes). + sampling_rate (`float`): + Sample rate of the audio waveform. + tuning (`float`): + Tuning deviation from A440 in fractions of a chroma bin. + power (`float`, *optional*, defaults to 2.0): + If 12.0, normalizes each column with their L2 norm. If 1.0, normalizes each column with their L1 norm. + weighting_parameters (`tuple[float, float]`, *optional*, defaults to `(5., 2.)`): + If specified, apply a Gaussian weighting parameterized by the first element of the tuple being the center and + the second element being the Gaussian half-width. + start_at_c_chroma (`bool`, *optional*, defaults to `True`): + If True, the filter bank will start at the 'C' pitch class. Otherwise, it will start at 'A'. + Returns: + `np.ndarray` of shape `(num_frequency_bins, num_chroma)` + """ + # Get the FFT bins, not counting the DC component + frequencies = np.linspace(0, sampling_rate, num_frequency_bins, endpoint=False)[1:] + + freq_bins = num_chroma * hertz_to_octave(frequencies, tuning=tuning, bins_per_octave=num_chroma) + + # make up a value for the 0 Hz bin = 1.5 octaves below bin 1 + # (so chroma is 50% rotated from bin 1, and bin width is broad) + freq_bins = np.concatenate(([freq_bins[0] - 1.5 * num_chroma], freq_bins)) + + bins_width = np.concatenate((np.maximum(freq_bins[1:] - freq_bins[:-1], 1.0), [1])) + + chroma_filters = np.subtract.outer(freq_bins, np.arange(0, num_chroma, dtype="d")).T + + num_chroma2 = np.round(float(num_chroma) / 2) + + # Project into range -num_chroma/2 .. num_chroma/2 + # add on fixed offset of 10*num_chroma to ensure all values passed to + # rem are positive + chroma_filters = np.remainder(chroma_filters + num_chroma2 + 10 * num_chroma, num_chroma) - num_chroma2 + + # Gaussian bumps - 2*D to make them narrower + chroma_filters = np.exp(-0.5 * (2 * chroma_filters / np.tile(bins_width, (num_chroma, 1))) ** 2) + + # normalize each column + if power is not None: + chroma_filters = chroma_filters / np.sum(chroma_filters**power, axis=0, keepdims=True) ** (1.0 / power) + + # Maybe apply scaling for fft bins + if weighting_parameters is not None: + center, half_width = weighting_parameters + chroma_filters *= np.tile( + np.exp(-0.5 * (((freq_bins / num_chroma - center) / half_width) ** 2)), + (num_chroma, 1), + ) + + if start_at_c_chroma: + chroma_filters = np.roll(chroma_filters, -3 * (num_chroma // 12), axis=0) + + # remove aliasing columns, copy to ensure row-contiguity + return np.ascontiguousarray(chroma_filters[:, : int(1 + num_frequency_bins / 2)]) + + +def mel_filter_bank( + num_frequency_bins: int, + num_mel_filters: int, + min_frequency: float, + max_frequency: float, + sampling_rate: int, + norm: Optional[str] = None, + mel_scale: str = "htk", + triangularize_in_mel_space: bool = False, +) -> np.ndarray: + """ + Creates a frequency bin conversion matrix used to obtain a mel spectrogram. This is called a *mel filter bank*, and + various implementation exist, which differ in the number of filters, the shape of the filters, the way the filters + are spaced, the bandwidth of the filters, and the manner in which the spectrum is warped. The goal of these + features is to approximate the non-linear human perception of the variation in pitch with respect to the frequency. + + Different banks of mel filters were introduced in the literature. The following variations are supported: + + - MFCC FB-20: introduced in 1980 by Davis and Mermelstein, it assumes a sampling frequency of 10 kHz and a speech + bandwidth of `[0, 4600]` Hz. + - MFCC FB-24 HTK: from the Cambridge HMM Toolkit (HTK) (1995) uses a filter bank of 24 filters for a speech + bandwidth of `[0, 8000]` Hz. This assumes sampling rate ≥ 16 kHz. + - MFCC FB-40: from the Auditory Toolbox for MATLAB written by Slaney in 1998, assumes a sampling rate of 16 kHz and + speech bandwidth of `[133, 6854]` Hz. This version also includes area normalization. + - HFCC-E FB-29 (Human Factor Cepstral Coefficients) of Skowronski and Harris (2004), assumes a sampling rate of + 12.5 kHz and speech bandwidth of `[0, 6250]` Hz. + + This code is adapted from *torchaudio* and *librosa*. Note that the default parameters of torchaudio's + `melscale_fbanks` implement the `"htk"` filters while librosa uses the `"slaney"` implementation. + + Args: + num_frequency_bins (`int`): + Number of frequency bins (should be the same as `n_fft // 2 + 1` where `n_fft` is the size of the Fourier Transform used to compute the spectrogram). + num_mel_filters (`int`): + Number of mel filters to generate. + min_frequency (`float`): + Lowest frequency of interest in Hz. + max_frequency (`float`): + Highest frequency of interest in Hz. This should not exceed `sampling_rate / 2`. + sampling_rate (`int`): + Sample rate of the audio waveform. + norm (`str`, *optional*): + If `"slaney"`, divide the triangular mel weights by the width of the mel band (area normalization). + mel_scale (`str`, *optional*, defaults to `"htk"`): + The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`. + triangularize_in_mel_space (`bool`, *optional*, defaults to `False`): + If this option is enabled, the triangular filter is applied in mel space rather than frequency space. This + should be set to `true` in order to get the same results as `torchaudio` when computing mel filters. + + Returns: + `np.ndarray` of shape (`num_frequency_bins`, `num_mel_filters`): Triangular filter bank matrix. This is a + projection matrix to go from a spectrogram to a mel spectrogram. + """ + if norm is not None and norm != "slaney": + raise ValueError('norm must be one of None or "slaney"') + + if num_frequency_bins < 2: + raise ValueError(f"Require num_frequency_bins: {num_frequency_bins} >= 2") + + if min_frequency > max_frequency: + raise ValueError(f"Require min_frequency: {min_frequency} <= max_frequency: {max_frequency}") + + # center points of the triangular mel filters + mel_min = hertz_to_mel(min_frequency, mel_scale=mel_scale) + mel_max = hertz_to_mel(max_frequency, mel_scale=mel_scale) + mel_freqs = np.linspace(mel_min, mel_max, num_mel_filters + 2) + filter_freqs = mel_to_hertz(mel_freqs, mel_scale=mel_scale) + + if triangularize_in_mel_space: + # frequencies of FFT bins in Hz, but filters triangularized in mel space + fft_bin_width = sampling_rate / ((num_frequency_bins - 1) * 2) + fft_freqs = hertz_to_mel(fft_bin_width * np.arange(num_frequency_bins), mel_scale=mel_scale) + filter_freqs = mel_freqs + else: + # frequencies of FFT bins in Hz + fft_freqs = np.linspace(0, sampling_rate // 2, num_frequency_bins) + + mel_filters = _create_triangular_filter_bank(fft_freqs, filter_freqs) + + if norm is not None and norm == "slaney": + # Slaney-style mel is scaled to be approx constant energy per channel + enorm = 2.0 / (filter_freqs[2 : num_mel_filters + 2] - filter_freqs[:num_mel_filters]) + mel_filters *= np.expand_dims(enorm, 0) + + if (mel_filters.max(axis=0) == 0.0).any(): + warnings.warn( + "At least one mel filter has all zero values. " + f"The value for `num_mel_filters` ({num_mel_filters}) may be set too high. " + f"Or, the value for `num_frequency_bins` ({num_frequency_bins}) may be set too low." + ) + + return mel_filters + + +def optimal_fft_length(window_length: int) -> int: + """ + Finds the best FFT input size for a given `window_length`. This function takes a given window length and, if not + already a power of two, rounds it up to the next power or two. + + The FFT algorithm works fastest when the length of the input is a power of two, which may be larger than the size + of the window or analysis frame. For example, if the window is 400 samples, using an FFT input size of 512 samples + is more optimal than an FFT size of 400 samples. Using a larger FFT size does not affect the detected frequencies, + it simply gives a higher frequency resolution (i.e. the frequency bins are smaller). + """ + return 2 ** int(np.ceil(np.log2(window_length))) + + +def window_function( + window_length: int, + name: str = "hann", + periodic: bool = True, + frame_length: Optional[int] = None, + center: bool = True, +) -> np.ndarray: + """ + Returns an array containing the specified window. This window is intended to be used with `stft`. + + The following window types are supported: + + - `"boxcar"`: a rectangular window + - `"hamming"`: the Hamming window + - `"hann"`: the Hann window + - `"povey"`: the Povey window + + Args: + window_length (`int`): + The length of the window in samples. + name (`str`, *optional*, defaults to `"hann"`): + The name of the window function. + periodic (`bool`, *optional*, defaults to `True`): + Whether the window is periodic or symmetric. + frame_length (`int`, *optional*): + The length of the analysis frames in samples. Provide a value for `frame_length` if the window is smaller + than the frame length, so that it will be zero-padded. + center (`bool`, *optional*, defaults to `True`): + Whether to center the window inside the FFT buffer. Only used when `frame_length` is provided. + + Returns: + `np.ndarray` of shape `(window_length,)` or `(frame_length,)` containing the window. + """ + length = window_length + 1 if periodic else window_length + + if name == "boxcar": + window = np.ones(length) + elif name in ["hamming", "hamming_window"]: + window = np.hamming(length) + elif name in ["hann", "hann_window"]: + window = np.hanning(length) + elif name == "povey": + window = np.power(np.hanning(length), 0.85) + else: + raise ValueError(f"Unknown window function '{name}'") + + if periodic: + window = window[:-1] + + if frame_length is None: + return window + + if window_length > frame_length: + raise ValueError( + f"Length of the window ({window_length}) may not be larger than frame_length ({frame_length})" + ) + + padded_window = np.zeros(frame_length) + offset = (frame_length - window_length) // 2 if center else 0 + padded_window[offset : offset + window_length] = window + return padded_window + + +# TODO This method does not support batching yet as we are mainly focused on inference. +def spectrogram( + waveform: np.ndarray, + window: np.ndarray, + frame_length: int, + hop_length: int, + fft_length: Optional[int] = None, + power: Optional[float] = 1.0, + center: bool = True, + pad_mode: str = "reflect", + onesided: bool = True, + dither: float = 0.0, + preemphasis: Optional[float] = None, + mel_filters: Optional[np.ndarray] = None, + mel_floor: float = 1e-10, + log_mel: Optional[str] = None, + reference: float = 1.0, + min_value: float = 1e-10, + db_range: Optional[float] = None, + remove_dc_offset: bool = False, + dtype: np.dtype = np.float32, +) -> np.ndarray: + """ + Calculates a spectrogram over one waveform using the Short-Time Fourier Transform. + + This function can create the following kinds of spectrograms: + + - amplitude spectrogram (`power = 1.0`) + - power spectrogram (`power = 2.0`) + - complex-valued spectrogram (`power = None`) + - log spectrogram (use `log_mel` argument) + - mel spectrogram (provide `mel_filters`) + - log-mel spectrogram (provide `mel_filters` and `log_mel`) + + How this works: + + 1. The input waveform is split into frames of size `frame_length` that are partially overlapping by `frame_length + - hop_length` samples. + 2. Each frame is multiplied by the window and placed into a buffer of size `fft_length`. + 3. The DFT is taken of each windowed frame. + 4. The results are stacked into a spectrogram. + + We make a distinction between the following "blocks" of sample data, each of which may have a different lengths: + + - The analysis frame. This is the size of the time slices that the input waveform is split into. + - The window. Each analysis frame is multiplied by the window to avoid spectral leakage. + - The FFT input buffer. The length of this determines how many frequency bins are in the spectrogram. + + In this implementation, the window is assumed to be zero-padded to have the same size as the analysis frame. A + padded window can be obtained from `window_function()`. The FFT input buffer may be larger than the analysis frame, + typically the next power of two. + + Note: This function is not optimized for speed yet. It should be mostly compatible with `librosa.stft` and + `torchaudio.functional.transforms.Spectrogram`, although it is more flexible due to the different ways spectrograms + can be constructed. + + Args: + waveform (`np.ndarray` of shape `(length,)`): + The input waveform. This must be a single real-valued, mono waveform. + window (`np.ndarray` of shape `(frame_length,)`): + The windowing function to apply, including zero-padding if necessary. The actual window length may be + shorter than `frame_length`, but we're assuming the array has already been zero-padded. + frame_length (`int`): + The length of the analysis frames in samples. With librosa this is always equal to `fft_length` but we also + allow smaller sizes. + hop_length (`int`): + The stride between successive analysis frames in samples. + fft_length (`int`, *optional*): + The size of the FFT buffer in samples. This determines how many frequency bins the spectrogram will have. + For optimal speed, this should be a power of two. If `None`, uses `frame_length`. + power (`float`, *optional*, defaults to 1.0): + If 1.0, returns the amplitude spectrogram. If 2.0, returns the power spectrogram. If `None`, returns + complex numbers. + center (`bool`, *optional*, defaults to `True`): + Whether to pad the waveform so that frame `t` is centered around time `t * hop_length`. If `False`, frame + `t` will start at time `t * hop_length`. + pad_mode (`str`, *optional*, defaults to `"reflect"`): + Padding mode used when `center` is `True`. Possible values are: `"constant"` (pad with zeros), `"edge"` + (pad with edge values), `"reflect"` (pads with mirrored values). + onesided (`bool`, *optional*, defaults to `True`): + If True, only computes the positive frequencies and returns a spectrogram containing `fft_length // 2 + 1` + frequency bins. If False, also computes the negative frequencies and returns `fft_length` frequency bins. + dither (`float`, *optional*, defaults to 0.0): + Adds dithering. In other words, adds a small Gaussian noise to each frame. + E.g. use 4.0 to add dithering with a normal distribution centered + around 0.0 with standard deviation 4.0, 0.0 means no dithering. + Dithering has similar effect as `mel_floor`. It reduces the high log_mel_fbank + values for signals with hard-zero sections, when VAD cutoff is present in the signal. + preemphasis (`float`, *optional*) + Coefficient for a low-pass filter that applies pre-emphasis before the DFT. + mel_filters (`np.ndarray` of shape `(num_freq_bins, num_mel_filters)`, *optional*): + The mel filter bank. If supplied, applies a this filter bank to create a mel spectrogram. + mel_floor (`float`, *optional*, defaults to 1e-10): + Minimum value of mel frequency banks. + log_mel (`str`, *optional*): + How to convert the spectrogram to log scale. Possible options are: `None` (don't convert), `"log"` (take + the natural logarithm) `"log10"` (take the base-10 logarithm), `"dB"` (convert to decibels). Can only be + used when `power` is not `None`. + reference (`float`, *optional*, defaults to 1.0): + Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set + the loudest part to 0 dB. Must be greater than zero. + min_value (`float`, *optional*, defaults to `1e-10`): + The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking + `log(0)`. For a power spectrogram, the default of `1e-10` corresponds to a minimum of -100 dB. For an + amplitude spectrogram, the value `1e-5` corresponds to -100 dB. Must be greater than zero. + db_range (`float`, *optional*): + Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the + peak value and the smallest value will never be more than 80 dB. Must be greater than zero. + remove_dc_offset (`bool`, *optional*): + Subtract mean from waveform on each frame, applied before pre-emphasis. This should be set to `true` in + order to get the same results as `torchaudio.compliance.kaldi.fbank` when computing mel filters. + dtype (`np.dtype`, *optional*, defaults to `np.float32`): + Data type of the spectrogram tensor. If `power` is None, this argument is ignored and the dtype will be + `np.complex64`. + + Returns: + `nd.array` containing a spectrogram of shape `(num_frequency_bins, length)` for a regular spectrogram or shape + `(num_mel_filters, length)` for a mel spectrogram. + """ + window_length = len(window) + + if fft_length is None: + fft_length = frame_length + + if frame_length > fft_length: + raise ValueError(f"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})") + + if window_length != frame_length: + raise ValueError(f"Length of the window ({window_length}) must equal frame_length ({frame_length})") + + if hop_length <= 0: + raise ValueError("hop_length must be greater than zero") + + if waveform.ndim != 1: + raise ValueError(f"Input waveform must have only one dimension, shape is {waveform.shape}") + + if np.iscomplexobj(waveform): + raise ValueError("Complex-valued input waveforms are not currently supported") + + if power is None and mel_filters is not None: + raise ValueError( + "You have provided `mel_filters` but `power` is `None`. Mel spectrogram computation is not yet supported for complex-valued spectrogram." + "Specify `power` to fix this issue." + ) + + # center pad the waveform + if center: + padding = [(int(frame_length // 2), int(frame_length // 2))] + waveform = np.pad(waveform, padding, mode=pad_mode) + + # promote to float64, since np.fft uses float64 internally + waveform = waveform.astype(np.float64) + window = window.astype(np.float64) + + # split waveform into frames of frame_length size + num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length)) + + num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length + spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64) + + # rfft is faster than fft + fft_func = np.fft.rfft if onesided else np.fft.fft + buffer = np.zeros(fft_length) + + timestep = 0 + for frame_idx in range(num_frames): + buffer[:frame_length] = waveform[timestep : timestep + frame_length] + + if dither != 0.0: + buffer[:frame_length] += dither * np.random.randn(frame_length) + + if remove_dc_offset: + buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean() + + if preemphasis is not None: + buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1] + buffer[0] *= 1 - preemphasis + + buffer[:frame_length] *= window + + spectrogram[frame_idx] = fft_func(buffer) + timestep += hop_length + + # note: ** is much faster than np.power + if power is not None: + spectrogram = np.abs(spectrogram, dtype=np.float64) ** power + + spectrogram = spectrogram.T + + if mel_filters is not None: + spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram)) + + if power is not None and log_mel is not None: + if log_mel == "log": + spectrogram = np.log(spectrogram) + elif log_mel == "log10": + spectrogram = np.log10(spectrogram) + elif log_mel == "dB": + if power == 1.0: + spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range) + elif power == 2.0: + spectrogram = power_to_db(spectrogram, reference, min_value, db_range) + else: + raise ValueError(f"Cannot use log_mel option '{log_mel}' with power {power}") + else: + raise ValueError(f"Unknown log_mel option: {log_mel}") + + spectrogram = np.asarray(spectrogram, dtype) + + return spectrogram + + +def spectrogram_batch( + waveform_list: list[np.ndarray], + window: np.ndarray, + frame_length: int, + hop_length: int, + fft_length: Optional[int] = None, + power: Optional[float] = 1.0, + center: bool = True, + pad_mode: str = "reflect", + onesided: bool = True, + dither: float = 0.0, + preemphasis: Optional[float] = None, + mel_filters: Optional[np.ndarray] = None, + mel_floor: float = 1e-10, + log_mel: Optional[str] = None, + reference: float = 1.0, + min_value: float = 1e-10, + db_range: Optional[float] = None, + remove_dc_offset: bool = False, + dtype: np.dtype = np.float32, +) -> list[np.ndarray]: + """ + Calculates spectrograms for a list of waveforms using the Short-Time Fourier Transform, optimized for batch processing. + This function extends the capabilities of the `spectrogram` function to handle multiple waveforms efficiently by leveraging broadcasting. + + It supports generating various types of spectrograms: + + - amplitude spectrogram (`power = 1.0`) + - power spectrogram (`power = 2.0`) + - complex-valued spectrogram (`power = None`) + - log spectrogram (use `log_mel` argument) + - mel spectrogram (provide `mel_filters`) + - log-mel spectrogram (provide `mel_filters` and `log_mel`) + + How this works: + + 1. The input waveform is split into frames of size `frame_length` that are partially overlapping by `frame_length + - hop_length` samples. + 2. Each frame is multiplied by the window and placed into a buffer of size `fft_length`. + 3. The DFT is taken of each windowed frame. + 4. The results are stacked into a spectrogram. + + We make a distinction between the following "blocks" of sample data, each of which may have a different lengths: + + - The analysis frame. This is the size of the time slices that the input waveform is split into. + - The window. Each analysis frame is multiplied by the window to avoid spectral leakage. + - The FFT input buffer. The length of this determines how many frequency bins are in the spectrogram. + + In this implementation, the window is assumed to be zero-padded to have the same size as the analysis frame. A + padded window can be obtained from `window_function()`. The FFT input buffer may be larger than the analysis frame, + typically the next power of two. + + Note: This function is designed for efficient batch processing of multiple waveforms but retains compatibility with individual waveform processing methods like `librosa.stft`. + + Args: + waveform_list (`list[np.ndarray]` with arrays of shape `(length,)`): + The list of input waveforms, each a single-channel (mono) signal. + window (`np.ndarray` of shape `(frame_length,)`): + The windowing function to apply, including zero-padding if necessary. + frame_length (`int`): + The length of each frame for analysis. + hop_length (`int`): + The step size between successive frames. + fft_length (`int`, *optional*): + The size of the FFT buffer, defining frequency bin resolution. + power (`float`, *optional*, defaults to 1.0): + Determines the type of spectrogram: 1.0 for amplitude, 2.0 for power, None for complex. + center (`bool`, *optional*, defaults to `True`): + Whether to center-pad the waveform frames. + pad_mode (`str`, *optional*, defaults to `"reflect"`): + The padding strategy when `center` is `True`. + onesided (`bool`, *optional*, defaults to `True`): + If True, returns a one-sided spectrogram for real input signals. + dither (`float`, *optional*, defaults to 0.0): + Adds dithering. In other words, adds a small Gaussian noise to each frame. + E.g. use 4.0 to add dithering with a normal distribution centered + around 0.0 with standard deviation 4.0, 0.0 means no dithering. + preemphasis (`float`, *optional*): + Applies a pre-emphasis filter to each frame. + mel_filters (`np.ndarray`, *optional*): + Mel filter bank for converting to mel spectrogram. + mel_floor (`float`, *optional*, defaults to 1e-10): + Floor value for mel spectrogram to avoid log(0). + log_mel (`str`, *optional*): + Specifies log scaling strategy; options are None, "log", "log10", "dB". + reference (`float`, *optional*, defaults to 1.0): + Reference value for dB conversion in log_mel. + min_value (`float`, *optional*, defaults to 1e-10): + Minimum floor value for log scale conversions. + db_range (`float`, *optional*): + Dynamic range for dB scale spectrograms. + remove_dc_offset (`bool`, *optional*): + Whether to remove the DC offset from each frame. + dtype (`np.dtype`, *optional*, defaults to `np.float32`): + Data type of the output spectrogram. + + Returns: + list[`np.ndarray`]: A list of spectrogram arrays, one for each input waveform. + """ + window_length = len(window) + + if fft_length is None: + fft_length = frame_length + + if frame_length > fft_length: + raise ValueError(f"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})") + + if window_length != frame_length: + raise ValueError(f"Length of the window ({window_length}) must equal frame_length ({frame_length})") + + if hop_length <= 0: + raise ValueError("hop_length must be greater than zero") + + # Check the dimensions of the waveform , and if waveform is complex + for waveform in waveform_list: + if waveform.ndim != 1: + raise ValueError(f"Input waveform must have only one dimension, shape is {waveform.shape}") + if np.iscomplexobj(waveform): + raise ValueError("Complex-valued input waveforms are not currently supported") + # Center pad the waveform + if center: + padding = [(int(frame_length // 2), int(frame_length // 2))] + waveform_list = [ + np.pad( + waveform, + padding, + mode=pad_mode, + ) + for waveform in waveform_list + ] + original_waveform_lengths = [ + len(waveform) for waveform in waveform_list + ] # these lengths will be used to remove padding later + + # Batch pad the waveform + max_length = max(original_waveform_lengths) + padded_waveform_batch = np.array( + [ + np.pad(waveform, (0, max_length - len(waveform)), mode="constant", constant_values=0) + for waveform in waveform_list + ], + dtype=dtype, + ) + + # Promote to float64, since np.fft uses float64 internally + padded_waveform_batch = padded_waveform_batch.astype(np.float64) + window = window.astype(np.float64) + + # Split waveform into frames of frame_length size + num_frames = int(1 + np.floor((padded_waveform_batch.shape[1] - frame_length) / hop_length)) + # these lengths will be used to remove padding later + true_num_frames = [int(1 + np.floor((length - frame_length) / hop_length)) for length in original_waveform_lengths] + num_batches = padded_waveform_batch.shape[0] + + num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length + spectrogram = np.empty((num_batches, num_frames, num_frequency_bins), dtype=np.complex64) + + # rfft is faster than fft + fft_func = np.fft.rfft if onesided else np.fft.fft + buffer = np.zeros((num_batches, fft_length)) + + for frame_idx in range(num_frames): + timestep = frame_idx * hop_length + buffer[:, :frame_length] = padded_waveform_batch[:, timestep : timestep + frame_length] + + if dither != 0.0: + buffer[:, :frame_length] += dither * np.random.randn(*buffer[:, :frame_length].shape) + + if remove_dc_offset: + buffer[:, :frame_length] -= buffer[:, :frame_length].mean(axis=1, keepdims=True) + + if preemphasis is not None: + buffer[:, 1:frame_length] -= preemphasis * buffer[:, : frame_length - 1] + buffer[:, 0] *= 1 - preemphasis + + buffer[:, :frame_length] *= window + + spectrogram[:, frame_idx] = fft_func(buffer) + + # Note: ** is much faster than np.power + if power is not None: + spectrogram = np.abs(spectrogram, dtype=np.float64) ** power + + # Apply mel filters if provided + if mel_filters is not None: + result = np.tensordot(spectrogram, mel_filters.T, axes=([2], [1])) + spectrogram = np.maximum(mel_floor, result) + + # Convert to log scale if specified + if power is not None and log_mel is not None: + if log_mel == "log": + spectrogram = np.log(spectrogram) + elif log_mel == "log10": + spectrogram = np.log10(spectrogram) + elif log_mel == "dB": + if power == 1.0: + spectrogram = amplitude_to_db_batch(spectrogram, reference, min_value, db_range) + elif power == 2.0: + spectrogram = power_to_db_batch(spectrogram, reference, min_value, db_range) + else: + raise ValueError(f"Cannot use log_mel option '{log_mel}' with power {power}") + else: + raise ValueError(f"Unknown log_mel option: {log_mel}") + + spectrogram = np.asarray(spectrogram, dtype) + + spectrogram_list = [spectrogram[i, : true_num_frames[i], :].T for i in range(len(true_num_frames))] + + return spectrogram_list + + +def power_to_db( + spectrogram: np.ndarray, + reference: float = 1.0, + min_value: float = 1e-10, + db_range: Optional[float] = None, +) -> np.ndarray: + """ + Converts a power spectrogram to the decibel scale. This computes `10 * log10(spectrogram / reference)`, using basic + logarithm properties for numerical stability. + + The motivation behind applying the log function on the (mel) spectrogram is that humans do not hear loudness on a + linear scale. Generally to double the perceived volume of a sound we need to put 8 times as much energy into it. + This means that large variations in energy may not sound all that different if the sound is loud to begin with. + This compression operation makes the (mel) spectrogram features match more closely what humans actually hear. + + Based on the implementation of `librosa.power_to_db`. + + Args: + spectrogram (`np.ndarray`): + The input power (mel) spectrogram. Note that a power spectrogram has the amplitudes squared! + reference (`float`, *optional*, defaults to 1.0): + Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set + the loudest part to 0 dB. Must be greater than zero. + min_value (`float`, *optional*, defaults to `1e-10`): + The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking + `log(0)`. The default of `1e-10` corresponds to a minimum of -100 dB. Must be greater than zero. + db_range (`float`, *optional*): + Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the + peak value and the smallest value will never be more than 80 dB. Must be greater than zero. + + Returns: + `np.ndarray`: the spectrogram in decibels + """ + if reference <= 0.0: + raise ValueError("reference must be greater than zero") + if min_value <= 0.0: + raise ValueError("min_value must be greater than zero") + + reference = max(min_value, reference) + + spectrogram = np.clip(spectrogram, a_min=min_value, a_max=None) + spectrogram = 10.0 * (np.log10(spectrogram) - np.log10(reference)) + + if db_range is not None: + if db_range <= 0.0: + raise ValueError("db_range must be greater than zero") + spectrogram = np.clip(spectrogram, a_min=spectrogram.max() - db_range, a_max=None) + + return spectrogram + + +def power_to_db_batch( + spectrogram: np.ndarray, + reference: float = 1.0, + min_value: float = 1e-10, + db_range: Optional[float] = None, +) -> np.ndarray: + """ + Converts a batch of power spectrograms to the decibel scale. This computes `10 * log10(spectrogram / reference)`, + using basic logarithm properties for numerical stability. + + This function supports batch processing, where each item in the batch is an individual power (mel) spectrogram. + + Args: + spectrogram (`np.ndarray`): + The input batch of power (mel) spectrograms. Expected shape is (batch_size, *spectrogram_shape). + Note that a power spectrogram has the amplitudes squared! + reference (`float`, *optional*, defaults to 1.0): + Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set + the loudest part to 0 dB. Must be greater than zero. + min_value (`float`, *optional*, defaults to `1e-10`): + The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking + `log(0)`. The default of `1e-10` corresponds to a minimum of -100 dB. Must be greater than zero. + db_range (`float`, *optional*): + Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the + peak value and the smallest value will never be more than 80 dB. Must be greater than zero. + + Returns: + `np.ndarray`: the batch of spectrograms in decibels + """ + if reference <= 0.0: + raise ValueError("reference must be greater than zero") + if min_value <= 0.0: + raise ValueError("min_value must be greater than zero") + + reference = max(min_value, reference) + + spectrogram = np.clip(spectrogram, a_min=min_value, a_max=None) + spectrogram = 10.0 * (np.log10(spectrogram) - np.log10(reference)) + + if db_range is not None: + if db_range <= 0.0: + raise ValueError("db_range must be greater than zero") + # Apply db_range clipping per batch item + max_values = spectrogram.max(axis=(1, 2), keepdims=True) + spectrogram = np.clip(spectrogram, a_min=max_values - db_range, a_max=None) + + return spectrogram + + +def amplitude_to_db( + spectrogram: np.ndarray, + reference: float = 1.0, + min_value: float = 1e-5, + db_range: Optional[float] = None, +) -> np.ndarray: + """ + Converts an amplitude spectrogram to the decibel scale. This computes `20 * log10(spectrogram / reference)`, using + basic logarithm properties for numerical stability. + + The motivation behind applying the log function on the (mel) spectrogram is that humans do not hear loudness on a + linear scale. Generally to double the perceived volume of a sound we need to put 8 times as much energy into it. + This means that large variations in energy may not sound all that different if the sound is loud to begin with. + This compression operation makes the (mel) spectrogram features match more closely what humans actually hear. + + Args: + spectrogram (`np.ndarray`): + The input amplitude (mel) spectrogram. + reference (`float`, *optional*, defaults to 1.0): + Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set + the loudest part to 0 dB. Must be greater than zero. + min_value (`float`, *optional*, defaults to `1e-5`): + The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking + `log(0)`. The default of `1e-5` corresponds to a minimum of -100 dB. Must be greater than zero. + db_range (`float`, *optional*): + Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the + peak value and the smallest value will never be more than 80 dB. Must be greater than zero. + + Returns: + `np.ndarray`: the spectrogram in decibels + """ + if reference <= 0.0: + raise ValueError("reference must be greater than zero") + if min_value <= 0.0: + raise ValueError("min_value must be greater than zero") + + reference = max(min_value, reference) + + spectrogram = np.clip(spectrogram, a_min=min_value, a_max=None) + spectrogram = 20.0 * (np.log10(spectrogram) - np.log10(reference)) + + if db_range is not None: + if db_range <= 0.0: + raise ValueError("db_range must be greater than zero") + spectrogram = np.clip(spectrogram, a_min=spectrogram.max() - db_range, a_max=None) + + return spectrogram + + +def amplitude_to_db_batch( + spectrogram: np.ndarray, reference: float = 1.0, min_value: float = 1e-5, db_range: Optional[float] = None +) -> np.ndarray: + """ + Converts a batch of amplitude spectrograms to the decibel scale. This computes `20 * log10(spectrogram / reference)`, + using basic logarithm properties for numerical stability. + + The function supports batch processing, where each item in the batch is an individual amplitude (mel) spectrogram. + + Args: + spectrogram (`np.ndarray`): + The input batch of amplitude (mel) spectrograms. Expected shape is (batch_size, *spectrogram_shape). + reference (`float`, *optional*, defaults to 1.0): + Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set + the loudest part to 0 dB. Must be greater than zero. + min_value (`float`, *optional*, defaults to `1e-5`): + The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking + `log(0)`. The default of `1e-5` corresponds to a minimum of -100 dB. Must be greater than zero. + db_range (`float`, *optional*): + Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the + peak value and the smallest value will never be more than 80 dB. Must be greater than zero. + + Returns: + `np.ndarray`: the batch of spectrograms in decibels + """ + if reference <= 0.0: + raise ValueError("reference must be greater than zero") + if min_value <= 0.0: + raise ValueError("min_value must be greater than zero") + + reference = max(min_value, reference) + + spectrogram = np.clip(spectrogram, a_min=min_value, a_max=None) + spectrogram = 20.0 * (np.log10(spectrogram) - np.log10(reference)) + + if db_range is not None: + if db_range <= 0.0: + raise ValueError("db_range must be greater than zero") + # Apply db_range clipping per batch item + max_values = spectrogram.max(axis=(1, 2), keepdims=True) + spectrogram = np.clip(spectrogram, a_min=max_values - db_range, a_max=None) + + return spectrogram diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/cache_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/cache_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..99beb0b610a1f88dbd2dadf473f829c0be12d8cc --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/cache_utils.py @@ -0,0 +1,1493 @@ +from abc import ABC, abstractmethod +from collections.abc import Iterable +from typing import Any, Optional + +import torch + +from .configuration_utils import PretrainedConfig +from .utils import ( + is_hqq_available, + is_quanto_greater, + is_torch_greater_or_equal, + is_torchdynamo_compiling, + logging, +) + + +if is_hqq_available(): + from hqq.core.quantize import Quantizer as HQQQuantizer + +_is_torch_greater_or_equal_than_2_7 = is_torch_greater_or_equal("2.7", accept_dev=True) + + +logger = logging.get_logger(__name__) + + +class CacheLayerMixin(ABC): + """Base, abstract class for a single layer's cache.""" + + is_compileable = False + + def __init__(self): + self.keys: Optional[torch.Tensor] = None + self.values: Optional[torch.Tensor] = None + self.is_initialized = False + + def __repr__(self): + return f"{self.__class__.__name__}" + + @abstractmethod + def lazy_initialization(self, key_states: torch.Tensor): ... + + @abstractmethod + def update( + self, key_states: torch.Tensor, value_states: torch.Tensor, cache_kwargs: Optional[dict[str, Any]] = None + ) -> tuple[torch.Tensor, torch.Tensor]: ... + + @abstractmethod + def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: ... + + @abstractmethod + def get_seq_length(self) -> int: ... + + @abstractmethod + def get_max_cache_shape(self) -> int: ... + + def offload(self): + """Offload this layer's data to CPU device.""" + if self.is_initialized: + self.keys = self.keys.to("cpu", non_blocking=True) + self.values = self.values.to("cpu", non_blocking=True) + + def prefetch(self): + """In case of layer offloading, this allows to move the data back to the layer's device ahead of time.""" + if self.is_initialized and self.keys.device != self.device: + self.keys = self.keys.to(self.device, non_blocking=True) + self.values = self.values.to(self.device, non_blocking=True) + + def reset(self) -> None: + """Resets the cache values while preserving the objects""" + if self.is_initialized: + self.keys.zero_() + self.values.zero_() + # This attribute is set on several Layers + if hasattr(self, "cumulative_length"): + self.cumulative_length = 0 + + def reorder_cache(self, beam_idx: torch.LongTensor) -> None: + """Reorders this layer's cache for beam search.""" + if self.get_seq_length() > 0: + self.keys = self.keys.index_select(0, beam_idx.to(self.keys.device)) + self.values = self.values.index_select(0, beam_idx.to(self.values.device)) + + +class DynamicLayer(CacheLayerMixin): + """ + A cache layer that grows dynamically as more tokens are generated. This is the default for generative models. + It stores the key and value states as tensors of shape `[batch_size, num_heads, seq_len, head_dim]`. + """ + + is_sliding = False + + def lazy_initialization(self, key_states: torch.Tensor): + self.dtype, self.device = key_states.dtype, key_states.device + self.keys = torch.tensor([], dtype=self.dtype, device=self.device) + self.values = torch.tensor([], dtype=self.dtype, device=self.device) + self.is_initialized = True + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Update the key and value caches in-place, and return the necessary keys and value states. + + Args: + key_states (`torch.Tensor`): The new key states to cache. + value_states (`torch.Tensor`): The new value states to cache. + cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache. + + Returns: + tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states. + """ + # Lazy initialization + if not self.is_initialized: + self.lazy_initialization(key_states) + + self.keys = torch.cat([self.keys, key_states], dim=-2) + self.values = torch.cat([self.values, value_states], dim=-2) + return self.keys, self.values + + def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: + """Return the length and offset of the cache, used to generate the mask""" + kv_offset = 0 + query_length = cache_position.shape[0] + kv_length = self.get_seq_length() + query_length + return kv_length, kv_offset + + def get_seq_length(self) -> int: + """Returns the sequence length of the cached states.""" + if not self.is_initialized or self.keys.numel() == 0: + return 0 + return self.keys.shape[-2] + + def get_max_cache_shape(self) -> int: + """Returns the maximum sequence length of the cache object. DynamicLayer does not have a maximum length.""" + return -1 + + def crop(self, max_length: int) -> None: + """ + Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be negative + to remove `max_length` tokens. + """ + if max_length < 0: + max_length = self.get_seq_length() - abs(max_length) + + if self.get_seq_length() <= max_length: + return + + self.keys = self.keys[..., :max_length, :] + self.values = self.values[..., :max_length, :] + + def batch_repeat_interleave(self, repeats: int) -> None: + """Repeat the cache `repeats` times in the batch dimension.""" + if self.get_seq_length() > 0: + self.keys = self.keys.repeat_interleave(repeats, dim=0) + self.values = self.values.repeat_interleave(repeats, dim=0) + + def batch_select_indices(self, indices: torch.Tensor) -> None: + """Only keep the `indices` in the batch dimension of the cache.""" + if self.get_seq_length() > 0: + self.keys = self.keys[indices, ...] + self.values = self.values[indices, ...] + + +class DynamicSlidingWindowLayer(DynamicLayer): + """ + A cache layer that grows dynamically as more tokens are generated, up until the sliding window size. + It stores the key and value states as tensors of shape `[batch_size, num_heads, min(seq_len, sliding_window), head_dim]`. + """ + + is_sliding = True + + def __init__(self, sliding_window: int): + super().__init__() + self.sliding_window = sliding_window + self.cumulative_length = 0 + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Update the key and value caches in-place, and return the necessary keys and value states. + + Args: + key_states (`torch.Tensor`): The new key states to cache. + value_states (`torch.Tensor`): The new value states to cache. + cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache. + + Returns: + tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states. + """ + # Lazy initialization + if not self.is_initialized: + self.lazy_initialization(key_states) + + self.cumulative_length += key_states.shape[-2] + + # Compute the full states + full_key_states = torch.cat([self.keys, key_states], dim=-2) + full_value_states = torch.cat([self.values, value_states], dim=-2) + # Only cache the last `self.sliding_window - 1` tokens (or all of them if lower than that) + self.keys = full_key_states[:, :, -self.sliding_window + 1 :, :] + self.values = full_value_states[:, :, -self.sliding_window + 1 :, :] + + # Return the full states + return full_key_states, full_value_states + + def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: + """Return the length and offset of the cache, used to generate the attention mask""" + query_length = cache_position.shape[0] + is_full = self.cumulative_length >= self.sliding_window + + kv_offset = max(self.cumulative_length - self.sliding_window + 1, 0) + if is_full: + kv_length = self.sliding_window - 1 + query_length + else: + kv_length = self.cumulative_length + query_length + + return kv_length, kv_offset + + def get_seq_length(self) -> int: + """Returns the sequence length of the cached states.""" + return self.cumulative_length + + def get_max_cache_shape(self) -> int: + """Return the maximum cache shape of the cache""" + return self.sliding_window + + def crop(self, max_length: int) -> None: + """ + Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be + negative to remove `max_length` tokens. + """ + if self.get_seq_length() >= self.sliding_window: + raise ValueError( + "Cannot `crop` a `DynamicSlidingWindowLayer` after it has seen more tokens than its" + "sliding window (otherwise some states are lost)" + ) + super().crop(max_length) + self.cumulative_length = self.keys.shape[-2] + + +class StaticLayer(CacheLayerMixin): + """ + A static cache layer that stores the key and value states as static tensors of shape `[batch_size, num_heads, max_cache_len), head_dim]`. + It lazily allocates its full backing tensors, and then mutates them in-place. Built for `torch.compile` support. + + Args: + max_cache_len (`int`): + Maximum number of tokens that can be stored, used for tensor preallocation. + """ + + is_compileable = True + is_sliding = False + + def __init__(self, max_cache_len: int): + super().__init__() + self.max_cache_len = max_cache_len + + def lazy_initialization(self, key_states: torch.Tensor): + """ + Lazy initialization of the keys and values tensors. This allows to get all properties (dtype, device, + num_heads in case of TP etc...) at runtime directly, which is extremely practical as it avoids moving + devices, dtypes etc later on for each `update` (which could break the static dynamo addresses as well). + + If this is unwanted, one can call `early_initialization(...)` on the Cache directly, which will call this + function ahead-of-time (this is required for `torch.export` for example). Note that for `compile`, as we + internally don't compile the prefill, this is guaranteed to have been called already when compiling. + If compiling the prefill as well, e.g. calling `model.compile(...)` before `generate` with a static cache, + it is still supported in general, but without guarantees depending on the compilation options (e.g. cuda graphs, + i.e. `mode="reduce-overhead"` is known to fail). But it will in general work correctly, and prefill should + not be compiled anyway for performances! + """ + self.max_batch_size, self.num_heads, _, self.head_dim = key_states.shape + self.dtype, self.device = key_states.dtype, key_states.device + + self.keys = torch.zeros( + (self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim), + dtype=self.dtype, + device=self.device, + ) + self.values = torch.zeros( + (self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim), + dtype=self.dtype, + device=self.device, + ) + # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing compiled graph + # breaks when updating the cache. However, it is not supported when tracing the graph, so we skip it in this case. + # As prefill should never be compiled, this is not an issue and it will still be run (except when users compile + # prefill explicitly, but this should be avoided!) + if not is_torchdynamo_compiling(): + torch._dynamo.mark_static_address(self.keys) + torch._dynamo.mark_static_address(self.values) + + self.is_initialized = True + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Update the key and value caches in-place, and return the necessary keys and value states. + + Args: + key_states (`torch.Tensor`): The new key states to cache. + value_states (`torch.Tensor`): The new value states to cache. + cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache. + + Returns: + tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states. + """ + # Lazy initialization + if not self.is_initialized: + self.lazy_initialization(key_states) + + # Some old models give None for `cache_position` or even omit passing `cache_kwargs` when used as cross-attention, + # in which case we should copy the whole Layer (key_states.shape[-2] == self.max_cache_len) + cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None + cache_position = ( + cache_position if cache_position is not None else torch.arange(key_states.shape[-2], device=self.device) + ) + + # Update the cache + try: + self.keys.index_copy_(2, cache_position, key_states) + self.values.index_copy_(2, cache_position, value_states) + except NotImplementedError: + # Fallback for devices like MPS where index_copy_ might not be supported. + self.keys[:, :, cache_position] = key_states + self.values[:, :, cache_position] = value_states + return self.keys, self.values + + def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: + """Return the length and offset of the cache, used to generate the attention mask""" + kv_offset = 0 + kv_length = self.max_cache_len + return kv_length, kv_offset + + def get_seq_length(self) -> int: + """Returns the sequence length of the cached states.""" + # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's + # limit the check to the first batch member and head dimension. + return (self.keys[0, 0].any(dim=-1)).sum() if self.is_initialized else 0 + + def get_max_cache_shape(self) -> int: + """Return the maximum cache shape of the cache""" + return self.max_cache_len + + +class StaticSlidingWindowLayer(StaticLayer): + """ + A static cache layer that stores the key and value states as static tensors of shape + `[batch_size, num_heads, min(max_cache_len, sliding_window), head_dim]`. It lazily allocates its full backing + tensors, and then mutates them in-place. Built for `torch.compile` support. + + Args: + max_cache_len (`int`): + Maximum number of tokens that can be stored, used for tensor preallocation. + sliding_window (`int`): + The size of the sliding window. + """ + + is_sliding = True + + def __init__(self, max_cache_len: int, sliding_window: int): + effective_max_cache_len = min(sliding_window, max_cache_len) + super().__init__(max_cache_len=effective_max_cache_len) + self.cumulative_length = 0 + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Update the key and value caches in-place, and return the necessary keys and value states. + + Args: + key_states (`torch.Tensor`): The new key states to cache. + value_states (`torch.Tensor`): The new value states to cache. + cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache. + + Returns: + tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states. + """ + # Lazy initialization + if not self.is_initialized: + self.lazy_initialization(key_states) + + # Some old models give None for `cache_position` or even omit passing `cache_kwargs` when used as cross-attention, + # in which case we should copy the whole Layer (key_states.shape[-2] == self.max_cache_len) + cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None + cache_position = ( + cache_position if cache_position is not None else torch.arange(key_states.shape[-2], device=self.device) + ) + + cumulative_length = self.cumulative_length + is_full = cumulative_length >= self.max_cache_len + # Update it now that we saved the value above + self.cumulative_length += key_states.shape[-2] + + if is_full: + # In general, we should use a much simpler `cat` here as well, independently of the states size. However, + # dynamo is currently bugged when doing it - see https://github.com/pytorch/pytorch/issues/159855 for more details + if key_states.shape[-2] == 1: + # Roll all values to the left by 1 position + new_keys = self.keys.roll(-1, dims=-2) + new_values = self.values.roll(-1, dims=-2) + # Overwrite the last position with new states + # (note: very important to use a tensor to index here, see https://github.com/pytorch/pytorch/issues/159855) + index = torch.tensor([-1], dtype=int, device=self.device) + new_keys[:, :, index] = key_states + new_values[:, :, index] = value_states + + # Copy back into `self` (do not just assign again) in order to keep the static dynamo address + self.keys.copy_(new_keys) + self.values.copy_(new_values) + # Very important to return the `self` tensors here, as they have the static dynamo address + return self.keys, self.values + # Already full but using more than 1 new token (e.g. prefill caching, chat continuation, etc...) + else: + full_key_states = torch.cat((self.keys[:, :, 1:, :], key_states), dim=-2) + full_value_states = torch.cat((self.values[:, :, 1:, :], value_states), dim=-2) + # Not yet full, but becoming full on this update + elif cumulative_length + key_states.shape[2] > self.max_cache_len: + # Fast prefill path, no need to cat() in this case, as the cache is currently empty + if cumulative_length == 0: + full_key_states = key_states + full_value_states = value_states + else: + full_key_states = torch.cat((self.keys[:, :, :cumulative_length, :], key_states), dim=-2) + full_value_states = torch.cat((self.values[:, :, :cumulative_length, :], value_states), dim=-2) + else: + try: + self.keys.index_copy_(2, cache_position, key_states) + self.values.index_copy_(2, cache_position, value_states) + except NotImplementedError: + self.keys[:, :, cache_position] = key_states + self.values[:, :, cache_position] = value_states + + # Very important to return the `self` tensors here, as they have the static dynamo address + return self.keys, self.values + + # We only cache the last `sliding_window` tokens + self.keys.copy_(full_key_states[:, :, -self.max_cache_len :, :]) + self.values.copy_(full_value_states[:, :, -self.max_cache_len :, :]) + # we should return the whole states instead of `self.keys/values` here, as otherwise we lose some context + return full_key_states, full_value_states + + def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: + """Return the length and offset of the cache, used to generate the attention mask""" + query_length = cache_position.shape[0] + sliding_window = self.max_cache_len + is_full = self.cumulative_length >= self.max_cache_len + + kv_offset = max(self.cumulative_length - sliding_window + 1, 0) + # The cache is already full + if is_full: + kv_length = sliding_window + query_length - 1 + # Not yet full, but becoming full on this update + elif self.cumulative_length + query_length > sliding_window: + kv_length = self.cumulative_length + query_length + # Here the Cache is still smaller than the local size, but we return the local size as it's static + else: + kv_length = sliding_window + + return kv_length, kv_offset + + def get_seq_length(self) -> int: + """Returns the sequence length of the cached states.""" + return self.cumulative_length + + +class QuantizedLayer(DynamicLayer): + """ + A quantized layer similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://huggingface.co/papers/2402.02750). + It allows the model to generate longer sequence length without allocating too much memory for the key and value caches by + applying quantization. + + The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length` + is set as a maximum capacity for the original precision cache. When the length goes beyond maximum capacity, the original + precision cache is discarded and moved into the quantized cache. The quantization is done per-channel with a set `q_group_size` + for both Keys and Values, in contrast to what was described in the paper. + """ + + def __init__( + self, + nbits: int = 4, + axis_key: int = 0, + axis_value: int = 0, + q_group_size: int = 64, + residual_length: int = 128, + ): + super().__init__() + self.nbits = nbits + self.axis_key = axis_key + self.axis_value = axis_value + self.q_group_size = q_group_size + self.residual_length = residual_length + self.cumulative_length = 0 + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Update the key and value caches in-place, and return the necessary keys and value states. + + Args: + key_states (`torch.Tensor`): The new key states to cache. + value_states (`torch.Tensor`): The new value states to cache. + cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache. + + Returns: + tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states. + """ + self.cumulative_length += key_states.shape[-2] + + # Lazy initialization + if not self.is_initialized: + self.lazy_initialization(key_states) + self._quantized_keys = self._quantize(key_states.contiguous(), axis=self.axis_key) + self._quantized_values = self._quantize(value_states.contiguous(), axis=self.axis_value) + return key_states, value_states + + dequant_keys = self._dequantize(self._quantized_keys) + dequant_values = self._dequantize(self._quantized_values) + keys_to_return = torch.cat([dequant_keys, self.keys, key_states], dim=-2) + values_to_return = torch.cat([dequant_values, self.values, value_states], dim=-2) + if self.keys.dim() == 4 and self.keys.shape[-2] + 1 >= self.residual_length: + self._quantized_keys = self._quantize(keys_to_return.contiguous(), axis=self.axis_key) + self._quantized_values = self._quantize(values_to_return.contiguous(), axis=self.axis_value) + self.keys = torch.tensor([], dtype=key_states.dtype, device=key_states.device) + self.values = torch.tensor([], dtype=key_states.dtype, device=key_states.device) + else: + self.keys = torch.cat([self.keys, key_states], dim=-2) + self.values = torch.cat([self.values, value_states], dim=-2) + + return keys_to_return, values_to_return + + @abstractmethod + def _quantize(self, tensor, axis): ... + + @abstractmethod + def _dequantize(self, q_tensor): ... + + def get_seq_length(self) -> int: + """Returns the sequence length of the cached states.""" + return self.cumulative_length + + +class QuantoQuantizedLayer(QuantizedLayer): + def __init__( + self, + nbits: int = 4, + axis_key: int = 0, + axis_value: int = 0, + q_group_size: int = 64, + residual_length: int = 128, + ): + super().__init__( + nbits=nbits, + axis_key=axis_key, + axis_value=axis_value, + q_group_size=q_group_size, + residual_length=residual_length, + ) + + # We need to import quanto here to avoid circular imports due to optimum/quanto/models/transformers_models.py + if is_quanto_greater("0.2.5", accept_dev=True): + from optimum.quanto import MaxOptimizer, qint2, qint4 + else: + raise ImportError( + "You need optimum-quanto package version to be greater or equal than 0.2.5 to use `QuantoQuantizedCache`. " + ) + + if self.nbits not in [2, 4]: + raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}") + + if self.axis_key not in [0, -1]: + raise ValueError(f"`axis_key` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_key}") + + if self.axis_value not in [0, -1]: + raise ValueError( + f"`axis_value` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_value}" + ) + + self.qtype = qint4 if self.nbits == 4 else qint2 + self.optimizer = MaxOptimizer() # hardcode as it's the only one for per-channel quantization + + def _quantize(self, tensor, axis): + from optimum.quanto import quantize_weight + + scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.q_group_size) + qtensor = quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.q_group_size) + return qtensor + + def _dequantize(self, qtensor): + return qtensor.dequantize() + + +class HQQQuantizedLayer(QuantizedLayer): + def __init__( + self, + nbits: int = 4, + axis_key: int = 0, + axis_value: int = 0, + q_group_size: int = 64, + residual_length: int = 128, + ): + super().__init__( + nbits=nbits, + axis_key=axis_key, + axis_value=axis_value, + q_group_size=q_group_size, + residual_length=residual_length, + ) + + if not is_hqq_available(): + raise ImportError("You need to install `hqq` to use `HQQQuantizedLayer`") + + if self.nbits not in [1, 2, 3, 4, 8]: + raise ValueError( + f"`nbits` for `HQQ` backend has to be one of [`1`, `2`, `3`, `4`, `8`] but got {self.nbits}" + ) + + if self.axis_key not in [0, 1]: + raise ValueError(f"`axis_key` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_key}") + + if self.axis_value not in [0, 1]: + raise ValueError(f"`axis_value` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_value}") + + self.quantizer = HQQQuantizer + + def _quantize(self, tensor, axis): + qtensor, meta = self.quantizer.quantize( + tensor, + axis=axis, + device=self.keys.device, + compute_dtype=self.keys.dtype, + nbits=self.nbits, + group_size=self.q_group_size, + ) + meta["compute_dtype"] = self.keys.dtype + self.quantizer.cuda(qtensor, meta=meta, device=self.keys.device) # Move to device and cast to dtype + meta["scale"] = meta["scale"].to(qtensor.device) + meta["zero"] = meta["zero"].to(qtensor.device) + return qtensor, meta + + def _dequantize(self, qtensor): + quant_tensor, meta = qtensor + tensor = self.quantizer.dequantize(quant_tensor, meta) + return tensor + + +class Cache: + """ + A `Cache` is mostly a list of `CacheLayerMixin` objects, one per model layer. It serves as a container for + the Cache of each layer. + + Args: + layers (`Optional`, *optional*): + A list of pre-created `CacheLayerMixin`. If omitted (`None`), then `layer_class_to_replicate` will + be used. + layer_class_to_replicate (`type[CacheLayerMixin]`, *optional*): + Only used if `layers` is omitted (`None`), in which case it will be used as the base class for each layer, + and the layers will be added lazily as soon as `update` is called with a `layer_idx` greater than the current + list of layers. + offloading (`bool`, *optional*, defaults to `False`): + Whether to perform offloading of the layers to `cpu`, to save GPU memory. + offload_only_non_sliding (`bool`, *optional*, defaults to `True`): + If `offloading` is `True`, this further decides if only the non-sliding layers will be offloaded (because + usually the sliding layers are small in size, so there is no need to offload them, and skipping it is faster). + """ + + def __init__( + self, + layers: Optional[list[CacheLayerMixin]] = None, + layer_class_to_replicate: Optional[type[CacheLayerMixin]] = None, + offloading: bool = False, + offload_only_non_sliding: bool = True, + ): + if layers is not None and layer_class_to_replicate is not None: + raise ValueError( + "You can construct a Cache either from a list `layers` of all the predefined `CacheLayer`, or from a " + "`layer_class_to_replicate`, in which case the Cache will append a new layer corresponding to " + "`layer_class_to_replicate` for each new call to `update` with an idx not already in the Cache." + ) + if layers is None and layer_class_to_replicate is None: + raise ValueError( + "You should provide exactly one of `layers` or `layer_class_to_replicate` to initialize a Cache." + ) + self.layers = layers if layers is not None else [] + self.layer_class_to_replicate = layer_class_to_replicate + self.offloading = offloading + if self.offloading: + self.only_non_sliding = offload_only_non_sliding + self.prefetch_stream = torch.Stream() if _is_torch_greater_or_equal_than_2_7 else torch.cuda.Stream() + + def __repr__(self): + return f"{self.__class__.__name__}(layers={self.layers})" + + def prefetch(self, layer_idx: int, only_non_sliding: bool = True): + """ + Prefetch a given layer on its device. If `only_non_sliding` is True, it will try to prefetch only the layers + which are non-sliding. If the `layer_idx` is outside the range, this will circle back to the first layers. + Note that we use a non-default stream for this, to avoid blocking. + """ + if only_non_sliding: + # Try to find next non-sliding, starting at `layer_idx` + try: + layer_idx = layer_idx + self.is_sliding[layer_idx:].index(False) + # In this case, we need to circle back to the beginning + except ValueError: + layer_idx = self.is_sliding.index(False) + else: + layer_idx = layer_idx if layer_idx < len(self.layers) else 0 + + # Prefetch + with self.prefetch_stream if _is_torch_greater_or_equal_than_2_7 else torch.cuda.stream(self.prefetch_stream): + self.layers[layer_idx].prefetch() + + def offload(self, layer_idx: int, only_non_sliding: bool = True): + """ + Offload a given `layer_idx`. If `only_non_sliding` is True, it will offload `layer_idx` only if it is a + non-sliding layer. Note that we do it on the default stream, so that we ensure all earlier + computation in the layer's `update` methods are finished. + """ + if not (only_non_sliding and self.is_sliding[layer_idx]): + self.layers[layer_idx].offload() + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`dict[str, Any]`, *optional*): + Additional arguments for the cache subclass. These are specific to each subclass and allow new types of + cache to be created. + + Return: + A tuple containing the updated key and value states. + """ + # In this case, the `layers` were not provided, and we must append as much as `layer_idx` + if self.layer_class_to_replicate is not None: + while len(self.layers) <= layer_idx: + self.layers.append(self.layer_class_to_replicate()) + + if self.offloading: + # Wait for the stream to finish if needed, and start prefetching the next layer + torch.cuda.default_stream(key_states.device).wait_stream(self.prefetch_stream) + self.prefetch(layer_idx + 1, self.only_non_sliding) + + keys, values = self.layers[layer_idx].update(key_states, value_states, cache_kwargs) + + if self.offloading: + self.offload(layer_idx, self.only_non_sliding) + + return keys, values + + def early_initialization( + self, batch_size: int, num_heads: int, head_dim: int, dtype: torch.dtype, device: torch.device + ): + """ + Initialize all the layers in advance (it's otherwise lazily initialized on the first `update` call). + This is useful for our `export` recipes, as `export` needs everything in advance. + """ + # Note that the initialization needs all dimensions (except -2), as well as device and dtype, so we use + # this fake tensor approach. It has size 0 on the -2 dimension, so it does not allocate any data (it only + # creates an empty tensor with correct shape, dtype and device), which is very efficient and practical + fake_keys_tensor = torch.zeros((batch_size, num_heads, 0, head_dim), dtype=dtype, device=device) + # Init all layers + for layer in self.layers: + layer.lazy_initialization(fake_keys_tensor) + + def get_seq_length(self, layer_idx: int = 0) -> int: + """Returns the sequence length of the cache for the given layer.""" + if layer_idx >= len(self.layers): + return 0 + return self.layers[layer_idx].get_seq_length() + + def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: + """ + Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for + the given layer at `layer_idx`. + The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns for each layer. + """ + # For DynamicCache, where the layers are created at runtime -> if it was not yet created, the size is + # simply the shape of `cache_position` + if layer_idx >= len(self.layers): + return cache_position.shape[0], 0 + return self.layers[layer_idx].get_mask_sizes(cache_position) + + def get_max_cache_shape(self, layer_idx: int = 0) -> int: + """Returns maximum sequence length of the cache object. Dynamic caches do not have a maximum length.""" + # For DynamicCache, where the layers are created at runtime -> if it was not yet created, return -1 + # as DynamicLayer does + if layer_idx >= len(self.layers): + return -1 + return self.layers[layer_idx].get_max_cache_shape() + + def reset(self): + """Recursively reset all layers tensors""" + for layer_idx in range(len(self.layers)): + self.layers[layer_idx].reset() + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorder the cache for beam search""" + for layer_idx in range(len(self.layers)): + self.layers[layer_idx].reorder_cache(beam_idx) + + def crop(self, max_length: int): + """Crop the cache to the given length""" + for layer_idx in range(len(self.layers)): + self.layers[layer_idx].crop(max_length) + + def batch_repeat_interleave(self, repeats: int): + """Repeat and interleave the cache""" + for layer_idx in range(len(self.layers)): + self.layers[layer_idx].batch_repeat_interleave(repeats) + + def batch_select_indices(self, indices: torch.Tensor): + """Select indices from the cache""" + for layer_idx in range(len(self.layers)): + self.layers[layer_idx].batch_select_indices(indices) + + @property + def max_batch_size(self) -> int: + """Return the maximum batch size of the cache""" + values = [layer.max_batch_size for layer in self.layers] + if len(set(values)) > 1: + raise ValueError(f"Max batch size is not consistent across layers: {values}") + return values[0] + + @property + def max_cache_len(self) -> int: + """Return the maximum cache length of the cache""" + values = [layer.max_cache_len for layer in self.layers] + return max(values) + + @property + def is_compileable(self) -> bool: + """Return whether the cache is compileable""" + # For DynamicCache dispatching the layers lazily (otherwise, all([]) is True) + if len(self.layers) == 0: + return False + return all(layer.is_compileable for layer in self.layers) + + @property + def is_initialized(self) -> bool: + """Return whether the cache data is initialized""" + return len(self.layers) > 0 and all(layer.is_initialized for layer in self.layers) + + @property + def is_sliding(self) -> list[bool]: + """Return whether the layers of the cache are sliding window""" + return [getattr(layer, "is_sliding", False) for layer in self.layers] + + def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: + """ + Support for backwards-compatible `past_key_values` indexing, e.g. `past_key_values[0][0].shape[2]` to get the + sequence length. + """ + if layer_idx < len(self.layers): + return self.layers[layer_idx].keys, self.layers[layer_idx].values + else: + raise KeyError( + f"Cache only has {len(self.layers)} layers, attempted to access layer with index {layer_idx}" + ) + + def __iter__(self): + """ + Support for backwards-compatible `past_key_values` iteration, e.g. `for x in past_key_values:` to iterate over + keys and values + """ + for layer_idx in range(len(self)): + yield (self.layers[layer_idx].keys, self.layers[layer_idx].values) + + def __len__(self): + """ + This value corresponds to the number of layers in the model. + """ + # Note: for DynamicCache, layers are initialized lazily, so this will not be accurate before the first + # forward through all the layers + return len(self.layers) + + +class DynamicCache(Cache): + """ + A cache that grows dynamically as more tokens are generated. This is the default for generative models. + It stores the key and value states as a list of `CacheLayer`, one for each layer. The expected shape for each tensor + in the `CacheLayer`s is `[batch_size, num_heads, seq_len, head_dim]`. + If a config is passed, it will additionally check for sliding or hybrid cache structure, greatly reducing the + memory requirement of the cached tensors to `[batch_size, num_heads, min(seq_len, sliding_window), head_dim]`. + + See `Cache` for details on common methods that are implemented by all cache classes. + + Args: + ddp_cache_data (`Iterable[tuple[torch.Tensor, torch.Tensor]]`, *optional*): + It was originally added for compatibility with `torch.distributed` (DDP). In a nutshell, it is + `map(gather_map, zip(*caches))`, i.e. each item in the iterable contains the key and value states + for a layer gathered across replicas by torch.distributed (shape=[global batch size, num_heads, seq_len, head_dim]). + Note: it needs to be the 1st arg as well to work correctly + config (`PretrainedConfig`, *optional*): + The config of the model for which this Cache will be used. If passed, it will be used to check for sliding + or hybrid layer structure, greatly reducing the memory requirement of the cached tensors to + `[batch_size, num_heads, min(seq_len, sliding_window), head_dim]`. + offloading (`bool`, *optional*, defaults to `False`): + Whether to perform offloading of the layers to `cpu`, to save GPU memory. + offload_only_non_sliding (`bool`, *optional*, defaults to `False`): + If `offloading` is `True`, this further decides if only the non-sliding layers will be offloaded (because + usually the sliding layers are small in size, so there is no need to offload them, and skipping it is faster). + + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache + + >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + + >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> past_key_values = DynamicCache(config=model.config) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + ``` + """ + + def __init__( + self, + ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor]]] = None, + config: Optional[PretrainedConfig] = None, + offloading: bool = False, + offload_only_non_sliding: bool = False, + ): + layers = [] + # If a config is passed, use it to infer the layer types and initialize accordingly + if config is not None: + decoder_config = config.get_text_config(decoder=True) + sliding_window = getattr(decoder_config, "sliding_window", None) or getattr( + decoder_config, "attention_chunk_size", None + ) + layer_types = getattr(decoder_config, "layer_types", None) + if layer_types is None: + layer_types = [ + "sliding_attention" if sliding_window is not None else "full_attention" + for _ in range(decoder_config.num_hidden_layers) + ] + # Some models have shared layers thus no cache is needed for them (e.g. Gemma3n) + if hasattr(decoder_config, "num_kv_shared_layers"): + layer_types = layer_types[: -decoder_config.num_kv_shared_layers] + + for layer_type in layer_types: + # From a cache point of view, both sliding and chunked are the same in how they should behave and how many + # states they should return - only the mask changes to make them different at the end! + if layer_type in ("sliding_attention", "chunked_attention"): + layers.append(DynamicSlidingWindowLayer(sliding_window=sliding_window)) + else: + layers.append(DynamicLayer()) + + # In this case, use the passed data to already fill in the Cache + if ddp_cache_data is not None: + # Init all the layers with the data + for layer_idx, (key_states, value_states) in enumerate(ddp_cache_data): + # If the config was not passed above, initialize a DynamicLayer for each entry of the ddp_data + if config is None: + layers.append(DynamicLayer()) + # Update the layer with the data + _, _ = layers[layer_idx].update(key_states, value_states) + + # If neither of config nor ddp_data was passed, then simply lazy init a full cache of DynamicLayer + if len(layers) == 0: + super().__init__( + layer_class_to_replicate=DynamicLayer, + offloading=offloading, + offload_only_non_sliding=offload_only_non_sliding, + ) + else: + super().__init__(layers=layers, offloading=offloading, offload_only_non_sliding=offload_only_non_sliding) + + def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor]]: + """ + Converts the `Cache` instance into the its equivalent in the legacy cache format. Used for + backward compatibility. + """ + legacy_cache = () + for layer in self.layers: + legacy_cache += ((layer.keys, layer.values),) + return legacy_cache + + @classmethod + def from_legacy_cache(cls, past_key_values: tuple[tuple[torch.Tensor, torch.Tensor]]) -> "DynamicCache": + """ + Converts a cache in the legacy cache format into an equivalent `Cache`. Used for + backward compatibility. + """ + cache = cls() + if past_key_values is None: + logger.warning_once("past_key_values should not be None in from_legacy_cache()") + if past_key_values is not None: + for layer_idx in range(len(past_key_values)): + key_states, value_states = past_key_values[layer_idx] + cache.update(key_states, value_states, layer_idx) + return cache + + +class StaticCache(Cache): + """ + Static Cache class to be used with `torch.compile(model)` and `torch.export()`. It will check the `config` + for potential hybrid cache structure, and initialize each layer accordingly. + + See `Cache` for details on common methods that are implemented by all cache classes. + + Args: + config (`PretrainedConfig`): + The config of the model for which this Cache will be used. It will be used to check for sliding + or hybrid layer structure, and initialize each layer accordingly. + max_cache_len (`int`): + The maximum number of tokens that this Cache should hold. + offloading (`bool`, *optional*, defaults to `False`): + Whether to perform offloading of the layers to `cpu`, to save GPU memory. + offload_only_non_sliding (`bool`, *optional*, defaults to `True`): + If `offloading` is `True`, this further decides if only the non-sliding layers will be offloaded (because + usually the sliding layers are small in size, so there is no need to offload them, and skipping it is faster). + + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache + + >>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf") + + >>> inputs = tokenizer(text="My name is Llama", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate + >>> max_generated_length = inputs.input_ids.shape[1] + 10 + >>> past_key_values = StaticCache(config=model.config, max_cache_len=max_generated_length) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + StaticCache() + ``` + """ + + # Pass-in kwargs as well to avoid crashing for BC (it used more arguments before) + def __init__( + self, + config: PretrainedConfig, + max_cache_len: int, + offloading: bool = False, + offload_only_non_sliding: bool = True, + **kwargs, + ): + config = config.get_text_config(decoder=True) + layer_types = getattr(config, "layer_types", None) + # If `layer_types` is not explicitly provided, infer if the model is fully sliding + if layer_types is None: + if getattr(config, "sliding_window", None) is not None: + layer_types = ["sliding_attention" for _ in range(config.num_hidden_layers)] + elif getattr(config, "attention_chunk_size", None) is not None: + layer_types = ["chunked_attention" for _ in range(config.num_hidden_layers)] + else: + layer_types = ["full_attention" for _ in range(config.num_hidden_layers)] + # Some models have shared layers thus no cache is needed for them (e.g. Gemma3n) + if hasattr(config, "num_kv_shared_layers"): + layer_types = layer_types[: -config.num_kv_shared_layers] + + layers = [] + for layer_type in layer_types: + if layer_type == "sliding_attention": + layer = StaticSlidingWindowLayer(max_cache_len=max_cache_len, sliding_window=config.sliding_window) + elif layer_type == "chunked_attention": + # From a cache point of view, both sliding and chunked are the same in how they should behave and how many + # states they should return - only the mask changes to make them different at the end! + layer = StaticSlidingWindowLayer( + max_cache_len=max_cache_len, sliding_window=config.attention_chunk_size + ) + else: + layer = StaticLayer(max_cache_len=max_cache_len) + layers.append(layer) + + super().__init__(layers=layers, offloading=offloading, offload_only_non_sliding=offload_only_non_sliding) + + +class QuantizedCache(Cache): + """ + A quantizer cache similar to what is described in the + [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://huggingface.co/papers/2402.02750). + It allows the model to generate longer sequence length without allocating too much memory for keys and values + by applying quantization. + The cache has two types of storage, one for original precision and one for the + quantized cache. A `residual length` is set as a maximum capacity for the original precision cache. When the + length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache. + The quantization is done per-channel with a set `q_group_size` for both keys and values, in contrast to what was + described in the paper. + + See `Cache` for details on common methods that are implemented by all cache classes. + + Args: + backend (`str`): + The quantization backend to use. One of `("quanto", "hqq"). + config (`PretrainedConfig`): + The config of the model for which this Cache will be used. + nbits (`int`, *optional*, defaults to 4): + The number of bits for quantization. + axis_key (`int`, *optional*, defaults to 0): + The axis on which to quantize the keys. + axis_value (`int`, *optional*, defaults to 0): + The axis on which to quantize the values. + q_group_size (`int`, *optional*, defaults to 64): + Quantization is done per-channel according to a set `q_group_size` for both keys and values. + residual_length (`int`, *optional*, defaults to 128): + Maximum capacity for the original precision cache + """ + + def __init__( + self, + backend: str, + config: PretrainedConfig, + nbits: int = 4, + axis_key: int = 0, + axis_value: int = 0, + q_group_size: int = 64, + residual_length: int = 128, + ): + if backend == "quanto": + layer_class = QuantoQuantizedLayer + elif backend == "hqq": + layer_class = HQQQuantizedLayer + else: + raise ValueError(f"Unknown quantization backend `{backend}`") + + config = config.get_text_config(decoder=True) + layers = [ + layer_class(nbits, axis_key, axis_value, q_group_size, residual_length) + for _ in range(config.num_hidden_layers) + ] + super().__init__(layers=layers) + + +class EncoderDecoderCache(Cache): + """ + Base, abstract class for all encoder-decoder caches. Can be used to hold combinations of self-attention and + cross-attention caches. + + See `Cache` for details on common methods that are implemented by all cache classes. + + Args: + caches (`Iterable`): + Usually an iterable of length 2, containing 2 `Cache` objects, the first one for self-attention, the + second one for cross-attention. Can optionally also be an iterable of length 1, containing a + `tuple[tuple[torch.Tensor]]` (usually used for compatibility with torch dp and ddp). + + Example: + + ```python + >>> from transformers import AutoProcessor, AutoModelForCausalLM, DynamicCache, EncoderDecoderCache + + >>> model = AutoModelForCausalLM.from_pretrained("openai/whisper-small") + >>> processor = AutoProcessor.from_pretrained("openai/whisper-small") + + >>> inputs = processor(audio=YOUR-AUDIO, return_tensors="pt") + + >>> # Prepare cache classes for encoder and decoder and pass it to model's forward + >>> self_attention_cache = DynamicCache(config=self.config) + >>> cross_attention_cache = DynamicCache(config=self.config) + >>> past_key_values = EncoderDecoderCache(self_attention_cache, cross_attention_cache) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + EncoderDecoderCache() + ``` + """ + + def __init__(self, *caches) -> None: + # For dp and ddp support, if only one argument is passed, it should be an iterable of tuples of tensors + if len(caches) == 1: + self.self_attention_cache = DynamicCache() + self.cross_attention_cache = DynamicCache() + # Populate cache from the iterable + for layer_idx, key_value_states in enumerate(caches[0]): + key_states, value_states = key_value_states[:2] + self.self_attention_cache.update(key_states, value_states, layer_idx) + if len(key_value_states) > 2: + key_states, value_states = key_value_states[2:] + self.cross_attention_cache.update(key_states, value_states, layer_idx) + # Otherwise, we should get two arguments, a self-attention cache and a cross-attention cache + elif len(caches) == 2: + if not isinstance(caches[0], Cache) or not isinstance(caches[1], Cache): + raise TypeError(f"One of the two arguments is not a Cache: {type(caches[0]) = }, {type(caches[1]) = }") + self.self_attention_cache = caches[0] + self.cross_attention_cache = caches[1] + # Error case + else: + raise ValueError(f"Expected 1 or 2 arguments, got {len(caches)}") + + self.is_updated = {} + for layer_idx in range(len(self.cross_attention_cache)): + self.is_updated[layer_idx] = bool(self.cross_attention_cache.get_seq_length(layer_idx) > 0) + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(self_attention_cache={self.self_attention_cache}, cross_attention_cache=" + f"{self.cross_attention_cache})" + ) + + def __iter__(self): + """ + Support for backwards-compatible `past_key_values` iteration, e.g. `for x in past_key_values:` to iterate over + keys and values + """ + for layer_idx in range(len(self)): + yield ( + self.self_attention_cache.layers[layer_idx].keys, + self.self_attention_cache.layers[layer_idx].values, + self.cross_attention_cache.layers[layer_idx].keys, + self.cross_attention_cache.layers[layer_idx].values, + ) + + def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Support for backwards-compatible `past_key_values` indexing, e.g. `past_key_values[0][0].shape[2]` to get the + sequence length. + """ + if layer_idx < len(self): + return ( + self.self_attention_cache.layers[layer_idx].keys, + self.self_attention_cache.layers[layer_idx].values, + self.cross_attention_cache.layers[layer_idx].keys, + self.cross_attention_cache.layers[layer_idx].values, + ) + else: + raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") + + def __len__(self): + """ + Support for backwards-compatible `past_key_values` length, e.g. `len(past_key_values)`. This value corresponds + to the number of layers in the model. + """ + return len(self.self_attention_cache) + + def to_legacy_cache(self) -> tuple[tuple[torch.Tensor]]: + """Converts the `EncoderDecoderCache` instance into its equivalent in the legacy cache format.""" + legacy_cache = () + if len(self.cross_attention_cache) > 0: + for self_attn, cross_attn in zip( + self.self_attention_cache.to_legacy_cache(), self.cross_attention_cache.to_legacy_cache() + ): + legacy_cache += (self_attn + cross_attn,) + else: + legacy_cache = self.self_attention_cache.to_legacy_cache() + return legacy_cache + + @classmethod + def from_legacy_cache( + cls, past_key_values: Optional[Iterable[tuple[torch.FloatTensor, ...]]] + ) -> "EncoderDecoderCache": + """Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`.""" + cache = cls(DynamicCache(), DynamicCache()) + if past_key_values is None: + logger.warning_once("past_key_values should not be None in from_legacy_cache()") + else: + for layer_idx, key_value_states in enumerate(past_key_values): + key_states, value_states = key_value_states[:2] + cache.self_attention_cache.update(key_states, value_states, layer_idx) + if len(key_value_states) > 2: + key_states, value_states = key_value_states[2:] + cache.cross_attention_cache.update(key_states, value_states, layer_idx) + cache.is_updated[layer_idx] = True + return cache + + def get_seq_length(self, layer_idx: int = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + return self.self_attention_cache.get_seq_length(layer_idx) + + def reset(self): + self.self_attention_cache.reset() + self.cross_attention_cache.reset() + for layer_idx in self.is_updated: + self.is_updated[layer_idx] = False + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + self.self_attention_cache.reorder_cache(beam_idx) + self.cross_attention_cache.reorder_cache(beam_idx) + + def check_dynamic_cache(self, method: str): + if not ( + isinstance(self.self_attention_cache, DynamicCache) + and isinstance(self.cross_attention_cache, DynamicCache) + ): + raise ValueError( + f"`{method}` is only defined for dynamic cache, got {self.self_attention_cache.__str__()} for the self " + f"attention cache and {self.cross_attention_cache.__str__()} for the cross attention cache." + ) + + # TODO(gante, sanchit-gandhi): move following functionality into `.generate` + def crop(self, maximum_length: int): + """ + Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be + negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search (on the Hub). + """ + self.check_dynamic_cache(self.crop.__name__) + self.self_attention_cache.crop(maximum_length) + + def batch_split(self, full_batch_size: int, split_size: int) -> "list[EncoderDecoderCache]": + """ + Split the current instance into a list of `DynamicCache` by the batch size. This will be used by + `_split_model_inputs()` in `generation.utils` + """ + self.check_dynamic_cache(self.batch_split.__name__) + self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size) + cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size) + + out = [] + for self_attn, cross_attn in zip(self_attention_cache, cross_attention_cache): + out.append(EncoderDecoderCache(self_attn, cross_attn)) + return out + + def batch_repeat_interleave(self, repeats: int): + """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search (on the Hub).""" + self.check_dynamic_cache(self.batch_repeat_interleave.__name__) + self.self_attention_cache.batch_repeat_interleave(repeats) + self.cross_attention_cache.batch_repeat_interleave(repeats) + + def batch_select_indices(self, indices: torch.Tensor): + """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search (on the Hub).""" + self.check_dynamic_cache(self.batch_select_indices.__name__) + self.self_attention_cache.batch_select_indices(indices) + self.cross_attention_cache.batch_select_indices(indices) + + def get_max_cache_shape(self) -> int: + """Returns the maximum sequence length (i.e. max capacity) of the cache object""" + return self.self_attention_cache.get_max_cache_shape() + + def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: + return self.self_attention_cache.get_mask_sizes(cache_position, layer_idx) + + @property + def is_sliding(self): + return self.self_attention_cache.is_sliding + + @property + def is_compileable(self) -> bool: + return self.self_attention_cache.is_compileable + + +### Deprecated classes + + +class SlidingWindowLayer(StaticSlidingWindowLayer): + def __init__(self, max_cache_len: int, sliding_window: int): + logger.warning_once( + "`SlidingWindowLayer` is deprecated and will be removed in version v4.59 " + "Use `StaticSlidingWindowLayer` instead, which is a better name for it." + ) + super().__init__(max_cache_len, sliding_window) + + +class ChunkedSlidingLayer(StaticSlidingWindowLayer): + def __init__(self, max_cache_len: int, sliding_window: int): + logger.warning_once( + "`ChunkedSlidingLayer` is deprecated and will be removed in version v4.59 " + "Use `StaticSlidingWindowLayer` instead, which has the exact same functionalities." + ) + super().__init__(max_cache_len, sliding_window) + + +class OffloadedCache(DynamicCache): + def __init__(self) -> None: + logger.warning_once( + "`OffloadedCache` is deprecated and will be removed in version v4.59 " + "Use `DynamicCache(offloading=True)` instead" + ) + super().__init__(offloading=True) + + +class OffloadedStaticCache(StaticCache): + def __init__(self, config: PretrainedConfig, max_cache_len: int, *args, **kwargs): + logger.warning_once( + "`OffloadedStaticCache` is deprecated and will be removed in version v4.59 " + "Use `StaticCache(..., offloading=True)` instead" + ) + super().__init__(config=config, max_cache_len=max_cache_len, offloading=True) + + +class SlidingWindowCache(StaticCache): + def __init__(self, config: PretrainedConfig, max_cache_len: int, *args, **kwargs): + logger.warning_once( + "`SlidingWindowCache` is deprecated and will be removed in version v4.59 " + "Use `StaticCache(...)` instead which will correctly infer the type of each layer." + ) + super().__init__(config=config, max_cache_len=max_cache_len) + + +class HybridCache(StaticCache): + def __init__(self, config: PretrainedConfig, max_cache_len: int, *args, **kwargs): + logger.warning_once( + "`HybridCache` is deprecated and will be removed in version v4.59 " + "Use `StaticCache(...)` instead which will correctly infer the type of each layer." + ) + super().__init__(config=config, max_cache_len=max_cache_len) + + +class HybridChunkedCache(StaticCache): + def __init__(self, config: PretrainedConfig, max_cache_len: int, *args, **kwargs): + logger.warning_once( + "`HybridChunkedCache` is deprecated and will be removed in version v4.59 " + "Use `StaticCache(...)` instead which will correctly infer the type of each layer." + ) + super().__init__(config=config, max_cache_len=max_cache_len) + + +class OffloadedHybridCache(StaticCache): + def __init__(self, config: PretrainedConfig, max_cache_len: int, *args, **kwargs): + logger.warning_once( + "`OffloadedHybridCache` is deprecated and will be removed in version v4.59 " + "Use `StaticCache(..., offload=True)` instead which will correctly infer the type of each layer." + ) + super().__init__(config=config, max_cache_len=max_cache_len, offloading=True) + + +class QuantoQuantizedCache(QuantizedCache): + def __init__( + self, + config: PretrainedConfig, + nbits: int = 4, + axis_key: int = 0, + axis_value: int = 0, + q_group_size: int = 64, + residual_length: int = 128, + ): + logger.warning_once( + "`QuantoQuantizedCache` is deprecated and will be removed in version v4.59 " + "Use `QuantizedCache(backend='quanto', ...)` instead." + ) + super().__init__("quanto", config, nbits, axis_key, axis_value, q_group_size, residual_length) + + +class HQQQuantizedCache(QuantizedCache): + def __init__( + self, + config: PretrainedConfig, + nbits: int = 4, + axis_key: int = 0, + axis_value: int = 0, + q_group_size: int = 64, + residual_length: int = 128, + ): + logger.warning_once( + "`HQQQuantizedCache` is deprecated and will be removed in version v4.59 " + "Use `QuantizedCache(backend='hqq', ...)` instead." + ) + super().__init__("hqq", config, nbits, axis_key, axis_value, q_group_size, residual_length) + + +class SinkCache(Cache): + """ + It is now a `custom_generate` repository on the Hub: https://huggingface.co/transformers-community/sink_cache. + See [these docs](https://huggingface.co/docs/transformers/generation_strategies#custom-decoding-methods) for + general `custom_generate`usage. + """ + + # TODO (joao, manuel): Remove this class in v4.59.0 + def __init__(self, **kwargs) -> None: + raise NotImplementedError( + "`SinkCache` has been moved as a `custom_generate` repository on the Hub: " + "https://huggingface.co/transformers-community/sink_cache. See the repository for usage examples." + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/convert_tf_hub_seq_to_seq_bert_to_pytorch.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/convert_tf_hub_seq_to_seq_bert_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..e2c825a45b605c271c3912bf5b4f3ff4b5c1a32c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/convert_tf_hub_seq_to_seq_bert_to_pytorch.py @@ -0,0 +1,86 @@ +# Copyright 2020 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert Seq2Seq TF Hub checkpoint.""" + +import argparse + +from . import ( + BertConfig, + BertGenerationConfig, + BertGenerationDecoder, + BertGenerationEncoder, + load_tf_weights_in_bert_generation, + logging, +) + + +logging.set_verbosity_info() + + +def convert_tf_checkpoint_to_pytorch(tf_hub_path, pytorch_dump_path, is_encoder_named_decoder, vocab_size, is_encoder): + # Initialise PyTorch model + bert_config = BertConfig.from_pretrained( + "google-bert/bert-large-cased", + vocab_size=vocab_size, + max_position_embeddings=512, + is_decoder=True, + add_cross_attention=True, + ) + bert_config_dict = bert_config.to_dict() + del bert_config_dict["type_vocab_size"] + config = BertGenerationConfig(**bert_config_dict) + if is_encoder: + model = BertGenerationEncoder(config) + else: + model = BertGenerationDecoder(config) + print(f"Building PyTorch model from configuration: {config}") + + # Load weights from tf checkpoint + load_tf_weights_in_bert_generation( + model, + tf_hub_path, + model_class="bert", + is_encoder_named_decoder=is_encoder_named_decoder, + is_encoder=is_encoder, + ) + + # Save pytorch-model + print(f"Save PyTorch model and config to {pytorch_dump_path}") + model.save_pretrained(pytorch_dump_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--tf_hub_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." + ) + parser.add_argument( + "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." + ) + parser.add_argument( + "--is_encoder_named_decoder", + action="store_true", + help="If decoder has to be renamed to encoder in PyTorch model.", + ) + parser.add_argument("--is_encoder", action="store_true", help="If model is an encoder.") + parser.add_argument("--vocab_size", default=50358, type=int, help="Vocab size of model") + args = parser.parse_args() + convert_tf_checkpoint_to_pytorch( + args.tf_hub_path, + args.pytorch_dump_path, + args.is_encoder_named_decoder, + args.vocab_size, + is_encoder=args.is_encoder, + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/debug_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/debug_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..920b1cf44dafd320729eef1e6a36b9a41741c83c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/debug_utils.py @@ -0,0 +1,346 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections + +from .utils import ExplicitEnum, is_torch_available, logging + + +if is_torch_available(): + import torch + + +logger = logging.get_logger(__name__) + + +class DebugUnderflowOverflow: + """ + This debug class helps detect and understand where the model starts getting very large or very small, and more + importantly `nan` or `inf` weight and activation elements. + + There are 2 working modes: + + 1. Underflow/overflow detection (default) + 2. Specific batch absolute min/max tracing without detection + + Mode 1: Underflow/overflow detection + + To activate the underflow/overflow detection, initialize the object with the model : + + ```python + debug_overflow = DebugUnderflowOverflow(model) + ``` + + then run the training as normal and if `nan` or `inf` gets detected in at least one of the weight, input or output + elements this module will throw an exception and will print `max_frames_to_save` frames that lead to this event, + each frame reporting + + 1. the fully qualified module name plus the class name whose `forward` was run + 2. the absolute min and max value of all elements for each module weights, and the inputs and output + + For example, here is the header and the last few frames in detection report for `google/mt5-small` run in fp16 + mixed precision : + + ``` + Detected inf/nan during batch_number=0 + Last 21 forward frames: + abs min abs max metadata + [...] + encoder.block.2.layer.1.DenseReluDense.wi_0 Linear + 2.17e-07 4.50e+00 weight + 1.79e-06 4.65e+00 input[0] + 2.68e-06 3.70e+01 output + encoder.block.2.layer.1.DenseReluDense.wi_1 Linear + 8.08e-07 2.66e+01 weight + 1.79e-06 4.65e+00 input[0] + 1.27e-04 2.37e+02 output + encoder.block.2.layer.1.DenseReluDense.wo Linear + 1.01e-06 6.44e+00 weight + 0.00e+00 9.74e+03 input[0] + 3.18e-04 6.27e+04 output + encoder.block.2.layer.1.DenseReluDense T5DenseGatedGeluDense + 1.79e-06 4.65e+00 input[0] + 3.18e-04 6.27e+04 output + encoder.block.2.layer.1.dropout Dropout + 3.18e-04 6.27e+04 input[0] + 0.00e+00 inf output + ``` + + You can see here, that `T5DenseGatedGeluDense.forward` resulted in output activations, whose absolute max value was + around 62.7K, which is very close to fp16's top limit of 64K. In the next frame we have `Dropout` which + renormalizes the weights, after it zeroed some of the elements, which pushes the absolute max value to more than + 64K, and we get an overflow. + + As you can see it's the previous frames that we need to look into when the numbers start going into very large for + fp16 numbers. + + The tracking is done in a forward hook, which gets invoked immediately after `forward` has completed. + + By default the last 21 frames are printed. You can change the default to adjust for your needs. For example : + + ```python + debug_overflow = DebugUnderflowOverflow(model, max_frames_to_save=100) + ``` + + To validate that you have set up this debugging feature correctly, and you intend to use it in a training that + may take hours to complete, first run it with normal tracing enabled for one of a few batches as explained in + the next section. + + + Mode 2. Specific batch absolute min/max tracing without detection + + The second work mode is per-batch tracing with the underflow/overflow detection feature turned off. + + Let's say you want to watch the absolute min and max values for all the ingredients of each `forward` call of a + given batch, and only do that for batches 1 and 3. Then you instantiate this class as : + + ```python + debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1, 3]) + ``` + + And now full batches 1 and 3 will be traced using the same format as explained above. Batches are 0-indexed. + + This is helpful if you know that the program starts misbehaving after a certain batch number, so you can + fast-forward right to that area. + + + Early stopping: + + You can also specify the batch number after which to stop the training, with : + + ```python + debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1, 3], abort_after_batch_num=3) + ``` + + This feature is mainly useful in the tracing mode, but you can use it for any mode. + + + **Performance**: + + As this module measures absolute `min`/``max` of each weight of the model on every forward it'll slow the training + down. Therefore remember to turn it off once the debugging needs have been met. + + Args: + model (`nn.Module`): + The model to debug. + max_frames_to_save (`int`, *optional*, defaults to 21): + How many frames back to record + trace_batch_nums(`list[int]`, *optional*, defaults to `[]`): + Which batch numbers to trace (turns detection off) + abort_after_batch_num (`int``, *optional*): + Whether to abort after a certain batch number has finished + """ + + def __init__(self, model, max_frames_to_save=21, trace_batch_nums=[], abort_after_batch_num=None): + self.model = model + self.trace_batch_nums = trace_batch_nums + self.abort_after_batch_num = abort_after_batch_num + + # keep a LIFO buffer of frames to dump as soon as inf/nan is encountered to give context to the problem emergence + self.frames = collections.deque([], max_frames_to_save) + self.frame = [] + self.batch_number = 0 + self.total_calls = 0 + self.detected_overflow = False + self.prefix = " " + + self.analyse_model() + + self.register_forward_hook() + + def save_frame(self, frame=None): + if frame is not None: + self.expand_frame(frame) + self.frames.append("\n".join(self.frame)) + self.frame = [] # start a new frame + + def expand_frame(self, line): + self.frame.append(line) + + def trace_frames(self): + print("\n".join(self.frames)) + self.frames = [] + + def reset_saved_frames(self): + self.frames = [] + + def dump_saved_frames(self): + print(f"\nDetected inf/nan during batch_number={self.batch_number}") + print(f"Last {len(self.frames)} forward frames:") + print(f"{'abs min':8} {'abs max':8} metadata") + print("\n".join(self.frames)) + print("\n\n") + self.frames = [] + + def analyse_model(self): + # extract the fully qualified module names, to be able to report at run time. e.g.: + # encoder.block.2.layer.0.SelfAttention.o + # + # for shared weights only the first shared module name will be registered + self.module_names = {m: name for name, m in self.model.named_modules()} + # self.longest_module_name = max(len(v) for v in self.module_names.values()) + + def analyse_variable(self, var, ctx): + if torch.is_tensor(var): + self.expand_frame(get_abs_min_max(var, ctx)) + if detect_overflow(var, ctx): + self.detected_overflow = True + elif var is None: + self.expand_frame(f"{'None':>17} {ctx}") + else: + self.expand_frame(f"{'not a tensor':>17} {ctx}") + + def batch_start_frame(self): + self.expand_frame(f"\n\n{self.prefix} *** Starting batch number={self.batch_number} ***") + self.expand_frame(f"{'abs min':8} {'abs max':8} metadata") + + def batch_end_frame(self): + self.expand_frame(f"{self.prefix} *** Finished batch number={self.batch_number - 1} ***\n\n") + + def create_frame(self, module, input, output): + self.expand_frame(f"{self.prefix} {self.module_names[module]} {module.__class__.__name__}") + + # params + for name, p in module.named_parameters(recurse=False): + self.analyse_variable(p, name) + + # inputs + if isinstance(input, tuple): + for i, x in enumerate(input): + self.analyse_variable(x, f"input[{i}]") + else: + self.analyse_variable(input, "input") + + # outputs + if isinstance(output, tuple): + for i, x in enumerate(output): + # possibly a tuple of tuples + if isinstance(x, tuple): + for j, y in enumerate(x): + self.analyse_variable(y, f"output[{i}][{j}]") + else: + self.analyse_variable(x, f"output[{i}]") + else: + self.analyse_variable(output, "output") + + self.save_frame() + + def register_forward_hook(self): + self.model.apply(self._register_forward_hook) + + def _register_forward_hook(self, module): + module.register_forward_hook(self.forward_hook) + + def forward_hook(self, module, input, output): + # - input is a tuple of packed inputs (could be non-Tensors) + # - output could be a Tensor or a tuple of Tensors and non-Tensors + + last_frame_of_batch = False + + trace_mode = self.batch_number in self.trace_batch_nums + if trace_mode: + self.reset_saved_frames() + + if self.total_calls == 0: + self.batch_start_frame() + self.total_calls += 1 + + # count batch numbers - the very first forward hook of the batch will be called when the + # batch completes - i.e. it gets called very last - we know this batch has finished + if module == self.model: + self.batch_number += 1 + last_frame_of_batch = True + + self.create_frame(module, input, output) + + # if last_frame_of_batch: + # self.batch_end_frame() + + if trace_mode: + self.trace_frames() + + if last_frame_of_batch: + self.batch_start_frame() + + if self.detected_overflow and not trace_mode: + self.dump_saved_frames() + + # now we can abort, as it's pointless to continue running + raise ValueError( + "DebugUnderflowOverflow: inf/nan detected, aborting as there is no point running further. " + "Please scroll up above this traceback to see the activation values prior to this event." + ) + + # abort after certain batch if requested to do so + if self.abort_after_batch_num is not None and self.batch_number > self.abort_after_batch_num: + raise ValueError( + f"DebugUnderflowOverflow: aborting after {self.batch_number} batches due to" + f" `abort_after_batch_num={self.abort_after_batch_num}` arg" + ) + + +def get_abs_min_max(var, ctx): + abs_var = var.abs() + return f"{abs_var.min():8.2e} {abs_var.max():8.2e} {ctx}" + + +def detect_overflow(var, ctx): + """ + Report whether the tensor contains any `nan` or `inf` entries. + + This is useful for detecting overflows/underflows and best to call right after the function that did some math that + modified the tensor in question. + + This function contains a few other helper features that you can enable and tweak directly if you want to track + various other things. + + Args: + var: the tensor variable to check + ctx: the message to print as a context + + Return: + `True` if `inf` or `nan` was detected, `False` otherwise + """ + detected = False + if torch.isnan(var).any().item(): + detected = True + print(f"{ctx} has nans") + if torch.isinf(var).any().item(): + detected = True + print(f"{ctx} has infs") + + # if needed to monitor large elements can enable the following + if 0: # and detected: + n100 = var[torch.ge(var.abs(), 100)] + if n100.numel() > 0: + print(f"{ctx}: n100={n100.numel()}") + n1000 = var[torch.ge(var.abs(), 1000)] + if n1000.numel() > 0: + print(f"{ctx}: n1000={n1000.numel()}") + n10000 = var[torch.ge(var.abs(), 10000)] + if n10000.numel() > 0: + print(f"{ctx}: n10000={n10000.numel()}") + + if 0: + print(f"min={var.min():9.2e} max={var.max():9.2e}") + + if 0: + print(f"min={var.min():9.2e} max={var.max():9.2e} var={var.var():9.2e} mean={var.mean():9.2e} ({ctx})") + + return detected + + +class DebugOption(ExplicitEnum): + UNDERFLOW_OVERFLOW = "underflow_overflow" + TPU_METRICS_DEBUG = "tpu_metrics_debug" diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/dependency_versions_table.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/dependency_versions_table.py new file mode 100644 index 0000000000000000000000000000000000000000..42bbcbaabfad340b2e08abe722dbdfe50639e40c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/dependency_versions_table.py @@ -0,0 +1,108 @@ +# THIS FILE HAS BEEN AUTOGENERATED. To update: +# 1. modify the `_deps` dict in setup.py +# 2. run `make deps_table_update`` +deps = { + "Pillow": "Pillow>=10.0.1,<=15.0", + "accelerate": "accelerate>=0.26.0", + "av": "av", + "beautifulsoup4": "beautifulsoup4", + "blobfile": "blobfile", + "codecarbon": "codecarbon>=2.8.1", + "cookiecutter": "cookiecutter==1.7.3", + "dataclasses": "dataclasses", + "datasets": "datasets>=2.15.0", + "deepspeed": "deepspeed>=0.9.3", + "diffusers": "diffusers", + "dill": "dill<0.3.5", + "evaluate": "evaluate>=0.2.0", + "faiss-cpu": "faiss-cpu", + "fastapi": "fastapi", + "filelock": "filelock", + "flax": "flax>=0.4.1,<=0.7.0", + "ftfy": "ftfy", + "fugashi": "fugashi>=1.0", + "GitPython": "GitPython<3.1.19", + "hf-doc-builder": "hf-doc-builder>=0.3.0", + "hf_xet": "hf_xet", + "huggingface-hub": "huggingface-hub>=0.34.0,<1.0", + "importlib_metadata": "importlib_metadata", + "ipadic": "ipadic>=1.0.0,<2.0", + "jax": "jax>=0.4.1,<=0.4.13", + "jaxlib": "jaxlib>=0.4.1,<=0.4.13", + "jinja2": "jinja2>=3.1.0", + "kenlm": "kenlm", + "keras": "keras>2.9,<2.16", + "keras-nlp": "keras-nlp>=0.3.1,<0.14.0", + "kernels": "kernels>=0.6.1,<=0.9", + "librosa": "librosa", + "natten": "natten>=0.14.6,<0.15.0", + "nltk": "nltk<=3.8.1", + "num2words": "num2words", + "numpy": "numpy>=1.17", + "onnxconverter-common": "onnxconverter-common", + "onnxruntime-tools": "onnxruntime-tools>=1.4.2", + "onnxruntime": "onnxruntime>=1.4.0", + "openai": "openai>=1.98.0", + "opencv-python": "opencv-python", + "optimum-benchmark": "optimum-benchmark>=0.3.0", + "optuna": "optuna", + "optax": "optax>=0.0.8,<=0.1.4", + "pandas": "pandas<2.3.0", + "packaging": "packaging>=20.0", + "parameterized": "parameterized>=0.9", + "phonemizer": "phonemizer", + "protobuf": "protobuf", + "psutil": "psutil", + "pyyaml": "pyyaml>=5.1", + "pydantic": "pydantic>=2", + "pytest": "pytest>=7.2.0", + "pytest-asyncio": "pytest-asyncio", + "pytest-rerunfailures": "pytest-rerunfailures<16.0", + "pytest-timeout": "pytest-timeout", + "pytest-xdist": "pytest-xdist", + "pytest-order": "pytest-order", + "python": "python>=3.9.0", + "ray[tune]": "ray[tune]>=2.7.0", + "regex": "regex!=2019.12.17", + "requests": "requests", + "rhoknp": "rhoknp>=1.1.0,<1.3.1", + "rjieba": "rjieba", + "rouge-score": "rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1", + "ruff": "ruff==0.13.1", + "sacrebleu": "sacrebleu>=1.4.12,<2.0.0", + "sacremoses": "sacremoses", + "safetensors": "safetensors>=0.4.3", + "sagemaker": "sagemaker>=2.31.0", + "schedulefree": "schedulefree>=1.2.6", + "scikit-learn": "scikit-learn", + "scipy": "scipy<1.13.0", + "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92", + "sigopt": "sigopt", + "starlette": "starlette", + "sudachipy": "sudachipy>=0.6.6", + "sudachidict_core": "sudachidict_core>=20220729", + "tensorboard": "tensorboard", + "tensorflow-cpu": "tensorflow-cpu>2.9,<2.16", + "tensorflow": "tensorflow>2.9,<2.16", + "tensorflow-text": "tensorflow-text<2.16", + "tensorflow-probability": "tensorflow-probability<0.24", + "tf2onnx": "tf2onnx", + "timeout-decorator": "timeout-decorator", + "tiktoken": "tiktoken", + "timm": "timm<=1.0.19,!=1.0.18", + "tokenizers": "tokenizers>=0.22.0,<=0.23.0", + "torch": "torch>=2.2", + "torchaudio": "torchaudio", + "torchvision": "torchvision", + "pyctcdecode": "pyctcdecode>=0.4.0", + "tqdm": "tqdm>=4.27", + "unidic": "unidic>=1.0.2", + "unidic_lite": "unidic_lite>=1.0.7", + "urllib3": "urllib3<2.0.0", + "uvicorn": "uvicorn", + "pytest-rich": "pytest-rich", + "libcst": "libcst", + "rich": "rich", + "opentelemetry-api": "opentelemetry-api", + "mistral-common[opencv]": "mistral-common[opencv]>=1.6.3", +} diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/dynamic_module_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/dynamic_module_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6d4e2bf48921ffdb5f94c02f966225d97efde05f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/dynamic_module_utils.py @@ -0,0 +1,843 @@ +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utilities to dynamically load objects from the Hub.""" + +import ast +import filecmp +import hashlib +import importlib +import importlib.metadata +import importlib.util +import keyword +import os +import re +import shutil +import signal +import sys +import threading +import warnings +from pathlib import Path +from types import ModuleType +from typing import Any, Optional, Union + +from huggingface_hub import try_to_load_from_cache +from packaging import version + +from .utils import ( + HF_MODULES_CACHE, + TRANSFORMERS_DYNAMIC_MODULE_NAME, + cached_file, + extract_commit_hash, + is_offline_mode, + logging, +) +from .utils.import_utils import VersionComparison, split_package_version + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def _sanitize_module_name(name: str) -> str: + r""" + Tries to sanitize a module name so that it can be used as a Python module. + + The following transformations are applied: + + 1. Replace `.` in module names with `_dot_`. + 2. Replace `-` in module names with `_hyphen_`. + 3. If the module name starts with a digit, prepend it with `_`. + 4. Warn if the sanitized name is a Python reserved keyword or not a valid identifier. + + If the input name is already a valid identifier, it is returned unchanged. + """ + # We not replacing `\W` characters with `_` to avoid collisions. Because `_` is a very common + # separator used in module names, replacing `\W` with `_` would create too many collisions. + # Once a module is imported, it is cached in `sys.modules` and the second import would return + # the first module, which might not be the expected behavior if name collisions happen. + new_name = name.replace(".", "_dot_").replace("-", "_hyphen_") + if new_name and new_name[0].isdigit(): + new_name = f"_{new_name}" + if keyword.iskeyword(new_name): + logger.warning( + f"The module name {new_name} (originally {name}) is a reserved keyword in Python. " + "Please rename the original module to avoid import issues." + ) + elif not new_name.isidentifier(): + logger.warning( + f"The module name {new_name} (originally {name}) is not a valid Python identifier. " + "Please rename the original module to avoid import issues." + ) + return new_name + + +_HF_REMOTE_CODE_LOCK = threading.Lock() + + +def init_hf_modules(): + """ + Creates the cache directory for modules with an init, and adds it to the Python path. + """ + # This function has already been executed if HF_MODULES_CACHE already is in the Python path. + if HF_MODULES_CACHE in sys.path: + return + + sys.path.append(HF_MODULES_CACHE) + os.makedirs(HF_MODULES_CACHE, exist_ok=True) + init_path = Path(HF_MODULES_CACHE) / "__init__.py" + if not init_path.exists(): + init_path.touch() + importlib.invalidate_caches() + + +def create_dynamic_module(name: Union[str, os.PathLike]) -> None: + """ + Creates a dynamic module in the cache directory for modules. + + Args: + name (`str` or `os.PathLike`): + The name of the dynamic module to create. + """ + init_hf_modules() + dynamic_module_path = (Path(HF_MODULES_CACHE) / name).resolve() + # If the parent module does not exist yet, recursively create it. + if not dynamic_module_path.parent.exists(): + create_dynamic_module(dynamic_module_path.parent) + os.makedirs(dynamic_module_path, exist_ok=True) + init_path = dynamic_module_path / "__init__.py" + if not init_path.exists(): + init_path.touch() + # It is extremely important to invalidate the cache when we change stuff in those modules, or users end up + # with errors about module that do not exist. Same for all other `invalidate_caches` in this file. + importlib.invalidate_caches() + + +def get_relative_imports(module_file: Union[str, os.PathLike]) -> list[str]: + """ + Get the list of modules that are relatively imported in a module file. + + Args: + module_file (`str` or `os.PathLike`): The module file to inspect. + + Returns: + `list[str]`: The list of relative imports in the module. + """ + with open(module_file, encoding="utf-8") as f: + content = f.read() + + # Imports of the form `import .xxx` + relative_imports = re.findall(r"^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE) + # Imports of the form `from .xxx import yyy` + relative_imports += re.findall(r"^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE) + # Unique-ify + return list(set(relative_imports)) + + +def get_relative_import_files(module_file: Union[str, os.PathLike]) -> list[str]: + """ + Get the list of all files that are needed for a given module. Note that this function recurses through the relative + imports (if a imports b and b imports c, it will return module files for b and c). + + Args: + module_file (`str` or `os.PathLike`): The module file to inspect. + + Returns: + `list[str]`: The list of all relative imports a given module needs (recursively), which will give us the list + of module files a given module needs. + """ + no_change = False + files_to_check = [module_file] + all_relative_imports = [] + + # Let's recurse through all relative imports + while not no_change: + new_imports = [] + for f in files_to_check: + new_imports.extend(get_relative_imports(f)) + + module_path = Path(module_file).parent + new_import_files = [f"{str(module_path / m)}.py" for m in new_imports] + files_to_check = [f for f in new_import_files if f not in all_relative_imports] + + no_change = len(files_to_check) == 0 + all_relative_imports.extend(files_to_check) + + return all_relative_imports + + +def get_imports(filename: Union[str, os.PathLike]) -> list[str]: + """ + Extracts all the libraries (not relative imports this time) that are imported in a file. + + Args: + filename (`str` or `os.PathLike`): The module file to inspect. + + Returns: + `list[str]`: The list of all packages required to use the input module. + """ + with open(filename, encoding="utf-8") as f: + content = f.read() + imported_modules = set() + + import transformers.utils + + def recursive_look_for_imports(node): + if isinstance(node, ast.Try): + return # Don't recurse into Try blocks and ignore imports in them + elif isinstance(node, ast.If): + test = node.test + for condition_node in ast.walk(test): + if isinstance(condition_node, ast.Call): + check_function = getattr(condition_node.func, "id", "") + if ( + check_function.endswith("available") + and check_function.startswith("is_flash_attn") + or hasattr(transformers.utils.import_utils, check_function) + ): + # Don't recurse into "if flash_attn_available()" or any "if library_available" blocks + # that appears in `transformers.utils.import_utils` and ignore imports in them + return + elif isinstance(node, ast.Import): + # Handle 'import x' statements + for alias in node.names: + top_module = alias.name.split(".")[0] + if top_module: + imported_modules.add(top_module) + elif isinstance(node, ast.ImportFrom): + # Handle 'from x import y' statements, ignoring relative imports + if node.level == 0 and node.module: + top_module = node.module.split(".")[0] + if top_module: + imported_modules.add(top_module) + + # Recursively visit all children + for child in ast.iter_child_nodes(node): + recursive_look_for_imports(child) + + tree = ast.parse(content) + recursive_look_for_imports(tree) + + return sorted(imported_modules) + + +def check_imports(filename: Union[str, os.PathLike]) -> list[str]: + """ + Check if the current Python environment contains all the libraries that are imported in a file. Will raise if a + library is missing. + + Args: + filename (`str` or `os.PathLike`): The module file to check. + + Returns: + `list[str]`: The list of relative imports in the file. + """ + imports = get_imports(filename) + missing_packages = [] + for imp in imports: + try: + importlib.import_module(imp) + except ImportError as exception: + logger.warning(f"Encountered exception while importing {imp}: {exception}") + # Some packages can fail with an ImportError because of a dependency issue. + # This check avoids hiding such errors. + # See https://github.com/huggingface/transformers/issues/33604 + if "No module named" in str(exception): + missing_packages.append(imp) + else: + raise + + if len(missing_packages) > 0: + raise ImportError( + "This modeling file requires the following packages that were not found in your environment: " + f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`" + ) + + return get_relative_imports(filename) + + +def get_class_in_module( + class_name: str, + module_path: Union[str, os.PathLike], + *, + force_reload: bool = False, +) -> type: + """ + Import a module on the cache directory for modules and extract a class from it. + + Args: + class_name (`str`): The name of the class to import. + module_path (`str` or `os.PathLike`): The path to the module to import. + force_reload (`bool`, *optional*, defaults to `False`): + Whether to reload the dynamic module from file if it already exists in `sys.modules`. + Otherwise, the module is only reloaded if the file has changed. + + Returns: + `typing.Type`: The class looked for. + """ + name = os.path.normpath(module_path) + name = name.removesuffix(".py") + name = name.replace(os.path.sep, ".") + module_file: Path = Path(HF_MODULES_CACHE) / module_path + with _HF_REMOTE_CODE_LOCK: + if force_reload: + sys.modules.pop(name, None) + importlib.invalidate_caches() + cached_module: Optional[ModuleType] = sys.modules.get(name) + module_spec = importlib.util.spec_from_file_location(name, location=module_file) + + # Hash the module file and all its relative imports to check if we need to reload it + module_files: list[Path] = [module_file] + sorted(map(Path, get_relative_import_files(module_file))) + module_hash: str = hashlib.sha256(b"".join(bytes(f) + f.read_bytes() for f in module_files)).hexdigest() + + module: ModuleType + if cached_module is None: + module = importlib.util.module_from_spec(module_spec) + # insert it into sys.modules before any loading begins + sys.modules[name] = module + else: + module = cached_module + # reload in both cases, unless the module is already imported and the hash hits + if getattr(module, "__transformers_module_hash__", "") != module_hash: + module_spec.loader.exec_module(module) + module.__transformers_module_hash__ = module_hash + return getattr(module, class_name) + + +def get_cached_module_file( + pretrained_model_name_or_path: Union[str, os.PathLike], + module_file: str, + cache_dir: Optional[Union[str, os.PathLike]] = None, + force_download: bool = False, + resume_download: Optional[bool] = None, + proxies: Optional[dict[str, str]] = None, + token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, + local_files_only: bool = False, + repo_type: Optional[str] = None, + _commit_hash: Optional[str] = None, + **deprecated_kwargs, +) -> str: + """ + Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached + Transformers module. + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`): + This can be either: + + - a string, the *model id* of a pretrained model configuration hosted inside a model repo on + huggingface.co. + - a path to a *directory* containing a configuration file saved using the + [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`. + + module_file (`str`): + The name of the module file containing the class to look for. + cache_dir (`str` or `os.PathLike`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the standard + cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force to (re-)download the configuration files and override the cached versions if they + exist. + resume_download: + Deprecated and ignored. All downloads are now resumed by default when possible. + Will be removed in v5 of Transformers. + proxies (`dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `hf auth login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + local_files_only (`bool`, *optional*, defaults to `False`): + If `True`, will only try to load the tokenizer configuration from local files. + repo_type (`str`, *optional*): + Specify the repo type (useful when downloading from a space for instance). + + + + Passing `token=True` is required when you want to use a private model. + + + + Returns: + `str`: The path to the module inside the cache. + """ + use_auth_token = deprecated_kwargs.pop("use_auth_token", None) + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.") + token = use_auth_token + + if is_offline_mode() and not local_files_only: + logger.info("Offline mode: forcing local_files_only=True") + local_files_only = True + + # Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file. + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + is_local = os.path.isdir(pretrained_model_name_or_path) + if is_local: + submodule = _sanitize_module_name(os.path.basename(pretrained_model_name_or_path)) + else: + submodule = os.path.sep.join(map(_sanitize_module_name, pretrained_model_name_or_path.split("/"))) + cached_module = try_to_load_from_cache( + pretrained_model_name_or_path, module_file, cache_dir=cache_dir, revision=_commit_hash, repo_type=repo_type + ) + + new_files = [] + try: + # Load from URL or cache if already cached + resolved_module_file = cached_file( + pretrained_model_name_or_path, + module_file, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + revision=revision, + repo_type=repo_type, + _commit_hash=_commit_hash, + ) + if not is_local and cached_module != resolved_module_file: + new_files.append(module_file) + + except OSError: + logger.info(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.") + raise + + # Check we have all the requirements in our environment + modules_needed = check_imports(resolved_module_file) + + # Now we move the module inside our cached dynamic modules. + full_submodule = TRANSFORMERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule + create_dynamic_module(full_submodule) + submodule_path = Path(HF_MODULES_CACHE) / full_submodule + if submodule == _sanitize_module_name(os.path.basename(pretrained_model_name_or_path)): + # We copy local files to avoid putting too many folders in sys.path. This copy is done when the file is new or + # has changed since last copy. + if not (submodule_path / module_file).exists() or not filecmp.cmp( + resolved_module_file, str(submodule_path / module_file) + ): + (submodule_path / module_file).parent.mkdir(parents=True, exist_ok=True) + shutil.copy(resolved_module_file, submodule_path / module_file) + importlib.invalidate_caches() + for module_needed in modules_needed: + module_needed = Path(module_file).parent / f"{module_needed}.py" + module_needed_file = os.path.join(pretrained_model_name_or_path, module_needed) + if not (submodule_path / module_needed).exists() or not filecmp.cmp( + module_needed_file, str(submodule_path / module_needed) + ): + shutil.copy(module_needed_file, submodule_path / module_needed) + importlib.invalidate_caches() + else: + # Get the commit hash + commit_hash = extract_commit_hash(resolved_module_file, _commit_hash) + + # The module file will end up being placed in a subfolder with the git hash of the repo. This way we get the + # benefit of versioning. + submodule_path = submodule_path / commit_hash + full_submodule = full_submodule + os.path.sep + commit_hash + full_submodule_module_file_path = os.path.join(full_submodule, module_file) + create_dynamic_module(Path(full_submodule_module_file_path).parent) + + if not (submodule_path / module_file).exists(): + shutil.copy(resolved_module_file, submodule_path / module_file) + importlib.invalidate_caches() + # Make sure we also have every file with relative + for module_needed in modules_needed: + if not ((submodule_path / module_file).parent / f"{module_needed}.py").exists(): + get_cached_module_file( + pretrained_model_name_or_path, + f"{Path(module_file).parent / module_needed}.py", + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + token=token, + revision=revision, + local_files_only=local_files_only, + _commit_hash=commit_hash, + ) + new_files.append(f"{module_needed}.py") + + if len(new_files) > 0 and revision is None: + new_files = "\n".join([f"- {f}" for f in new_files]) + repo_type_str = "" if repo_type is None else f"{repo_type}s/" + url = f"https://huggingface.co/{repo_type_str}{pretrained_model_name_or_path}" + logger.warning( + f"A new version of the following files was downloaded from {url}:\n{new_files}" + "\n. Make sure to double-check they do not contain any added malicious code. To avoid downloading new " + "versions of the code file, you can pin a revision." + ) + + return os.path.join(full_submodule, module_file) + + +def get_class_from_dynamic_module( + class_reference: str, + pretrained_model_name_or_path: Union[str, os.PathLike], + cache_dir: Optional[Union[str, os.PathLike]] = None, + force_download: bool = False, + resume_download: Optional[bool] = None, + proxies: Optional[dict[str, str]] = None, + token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, + local_files_only: bool = False, + repo_type: Optional[str] = None, + code_revision: Optional[str] = None, + **kwargs, +) -> type: + """ + Extracts a class from a module file, present in the local folder or repository of a model. + + + + Calling this function will execute the code in the module file found locally or downloaded from the Hub. It should + therefore only be called on trusted repos. + + + + + + Args: + class_reference (`str`): + The full name of the class to load, including its module and optionally its repo. + pretrained_model_name_or_path (`str` or `os.PathLike`): + This can be either: + + - a string, the *model id* of a pretrained model configuration hosted inside a model repo on + huggingface.co. + - a path to a *directory* containing a configuration file saved using the + [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`. + + This is used when `class_reference` does not specify another repo. + module_file (`str`): + The name of the module file containing the class to look for. + class_name (`str`): + The name of the class to import in the module. + cache_dir (`str` or `os.PathLike`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the standard + cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force to (re-)download the configuration files and override the cached versions if they + exist. + resume_download: + Deprecated and ignored. All downloads are now resumed by default when possible. + Will be removed in v5 of Transformers. + proxies (`dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. + token (`str` or `bool`, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `hf auth login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + local_files_only (`bool`, *optional*, defaults to `False`): + If `True`, will only try to load the tokenizer configuration from local files. + repo_type (`str`, *optional*): + Specify the repo type (useful when downloading from a space for instance). + code_revision (`str`, *optional*, defaults to `"main"`): + The specific revision to use for the code on the Hub, if the code leaves in a different repository than the + rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for + storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git. + + + + Passing `token=True` is required when you want to use a private model. + + + + Returns: + `typing.Type`: The class, dynamically imported from the module. + + Examples: + + ```python + # Download module `modeling.py` from huggingface.co and cache then extract the class `MyBertModel` from this + # module. + cls = get_class_from_dynamic_module("modeling.MyBertModel", "sgugger/my-bert-model") + + # Download module `modeling.py` from a given repo and cache then extract the class `MyBertModel` from this + # module. + cls = get_class_from_dynamic_module("sgugger/my-bert-model--modeling.MyBertModel", "sgugger/another-bert-model") + ```""" + use_auth_token = kwargs.pop("use_auth_token", None) + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.") + token = use_auth_token + + # Catch the name of the repo if it's specified in `class_reference` + if "--" in class_reference: + repo_id, class_reference = class_reference.split("--") + else: + repo_id = pretrained_model_name_or_path + module_file, class_name = class_reference.split(".") + + if code_revision is None and pretrained_model_name_or_path == repo_id: + code_revision = revision + # And lastly we get the class inside our newly created module + final_module = get_cached_module_file( + repo_id, + module_file + ".py", + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + token=token, + revision=code_revision, + local_files_only=local_files_only, + repo_type=repo_type, + ) + return get_class_in_module(class_name, final_module, force_reload=force_download) + + +def custom_object_save(obj: Any, folder: Union[str, os.PathLike], config: Optional[dict] = None) -> list[str]: + """ + Save the modeling files corresponding to a custom model/configuration/tokenizer etc. in a given folder. Optionally + adds the proper fields in a config. + + Args: + obj (`Any`): The object for which to save the module files. + folder (`str` or `os.PathLike`): The folder where to save. + config (`PretrainedConfig` or dictionary, `optional`): + A config in which to register the auto_map corresponding to this custom object. + + Returns: + `list[str]`: The list of files saved. + """ + if obj.__module__ == "__main__": + logger.warning( + f"We can't save the code defining {obj} in {folder} as it's been defined in __main__. You should put " + "this code in a separate module so we can include it in the saved folder and make it easier to share via " + "the Hub." + ) + return + + def _set_auto_map_in_config(_config): + module_name = obj.__class__.__module__ + last_module = module_name.split(".")[-1] + full_name = f"{last_module}.{obj.__class__.__name__}" + # Special handling for tokenizers + if "Tokenizer" in full_name: + slow_tokenizer_class = None + fast_tokenizer_class = None + if obj.__class__.__name__.endswith("Fast"): + # Fast tokenizer: we have the fast tokenizer class and we may have the slow one has an attribute. + fast_tokenizer_class = f"{last_module}.{obj.__class__.__name__}" + if getattr(obj, "slow_tokenizer_class", None) is not None: + slow_tokenizer = getattr(obj, "slow_tokenizer_class") + slow_tok_module_name = slow_tokenizer.__module__ + last_slow_tok_module = slow_tok_module_name.split(".")[-1] + slow_tokenizer_class = f"{last_slow_tok_module}.{slow_tokenizer.__name__}" + else: + # Slow tokenizer: no way to have the fast class + slow_tokenizer_class = f"{last_module}.{obj.__class__.__name__}" + + full_name = (slow_tokenizer_class, fast_tokenizer_class) + + if isinstance(_config, dict): + auto_map = _config.get("auto_map", {}) + auto_map[obj._auto_class] = full_name + _config["auto_map"] = auto_map + elif getattr(_config, "auto_map", None) is not None: + _config.auto_map[obj._auto_class] = full_name + else: + _config.auto_map = {obj._auto_class: full_name} + + # Add object class to the config auto_map + if isinstance(config, (list, tuple)): + for cfg in config: + _set_auto_map_in_config(cfg) + elif config is not None: + _set_auto_map_in_config(config) + + result = [] + # Copy module file to the output folder. + object_file = sys.modules[obj.__module__].__file__ + dest_file = Path(folder) / (Path(object_file).name) + shutil.copy(object_file, dest_file) + result.append(dest_file) + + # Gather all relative imports recursively and make sure they are copied as well. + for needed_file in get_relative_import_files(object_file): + dest_file = Path(folder) / (Path(needed_file).name) + shutil.copy(needed_file, dest_file) + result.append(dest_file) + + return result + + +def _raise_timeout_error(signum, frame): + raise ValueError( + "Loading this model requires you to execute custom code contained in the model repository on your local " + "machine. Please set the option `trust_remote_code=True` to permit loading of this model." + ) + + +TIME_OUT_REMOTE_CODE = 15 + + +def resolve_trust_remote_code( + trust_remote_code, model_name, has_local_code, has_remote_code, error_message=None, upstream_repo=None +): + """ + Resolves the `trust_remote_code` argument. If there is remote code to be loaded, the user must opt-in to loading + it. + + Args: + trust_remote_code (`bool` or `None`): + User-defined `trust_remote_code` value. + model_name (`str`): + The name of the model repository in huggingface.co. + has_local_code (`bool`): + Whether the model has local code. + has_remote_code (`bool`): + Whether the model has remote code. + error_message (`str`, *optional*): + Custom error message to display if there is remote code to load and the user didn't opt-in. If unset, the error + message will be regarding loading a model with custom code. + + Returns: + The resolved `trust_remote_code` value. + """ + if error_message is None: + if upstream_repo is not None: + error_message = ( + f"The repository {model_name} references custom code contained in {upstream_repo} which " + f"must be executed to correctly load the model. You can inspect the repository " + f"content at https://hf.co/{upstream_repo} .\n" + ) + elif os.path.isdir(model_name): + error_message = ( + f"The repository {model_name} contains custom code which must be executed " + f"to correctly load the model. You can inspect the repository " + f"content at {os.path.abspath(model_name)} .\n" + ) + else: + error_message = ( + f"The repository {model_name} contains custom code which must be executed " + f"to correctly load the model. You can inspect the repository " + f"content at https://hf.co/{model_name} .\n" + ) + + if trust_remote_code is None: + if has_local_code: + trust_remote_code = False + elif has_remote_code and TIME_OUT_REMOTE_CODE > 0: + prev_sig_handler = None + try: + prev_sig_handler = signal.signal(signal.SIGALRM, _raise_timeout_error) + signal.alarm(TIME_OUT_REMOTE_CODE) + while trust_remote_code is None: + answer = input( + f"{error_message} You can inspect the repository content at https://hf.co/{model_name}.\n" + f"You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n\n" + f"Do you wish to run the custom code? [y/N] " + ) + if answer.lower() in ["yes", "y", "1"]: + trust_remote_code = True + elif answer.lower() in ["no", "n", "0", ""]: + trust_remote_code = False + signal.alarm(0) + except Exception: + # OS which does not support signal.SIGALRM + raise ValueError( + f"{error_message} You can inspect the repository content at https://hf.co/{model_name}.\n" + f"Please pass the argument `trust_remote_code=True` to allow custom code to be run." + ) + finally: + if prev_sig_handler is not None: + signal.signal(signal.SIGALRM, prev_sig_handler) + signal.alarm(0) + elif has_remote_code: + # For the CI which puts the timeout at 0 + _raise_timeout_error(None, None) + + if has_remote_code and not has_local_code and not trust_remote_code: + raise ValueError( + f"{error_message} You can inspect the repository content at https://hf.co/{model_name}.\n" + f"Please pass the argument `trust_remote_code=True` to allow custom code to be run." + ) + + return trust_remote_code + + +def check_python_requirements(path_or_repo_id, requirements_file="requirements.txt", **kwargs): + """ + Tries to locate `requirements_file` in a local folder or repo, and confirms that the environment has all the + python dependencies installed. + + Args: + path_or_repo_id (`str` or `os.PathLike`): + This can be either: + - a string, the *model id* of a model repo on huggingface.co. + - a path to a *directory* potentially containing the file. + kwargs (`dict[str, Any]`, *optional*): + Additional arguments to pass to `cached_file`. + """ + failed = [] # error messages regarding requirements + try: + requirements = cached_file(path_or_repo_id=path_or_repo_id, filename=requirements_file, **kwargs) + with open(requirements, "r") as f: + requirements = f.readlines() + + for requirement in requirements: + requirement = requirement.strip() + if not requirement or requirement.startswith("#"): # skip empty lines and comments + continue + + try: + # e.g. "torch>2.6.0" -> "torch", ">", "2.6.0" + package_name, delimiter, version_number = split_package_version(requirement) + except ValueError: # e.g. "torch", as opposed to "torch>2.6.0" + package_name = requirement + delimiter, version_number = None, None + + try: + local_package_version = importlib.metadata.version(package_name) + except importlib.metadata.PackageNotFoundError: + failed.append(f"{requirement} (installed: None)") + continue + + if delimiter is not None and version_number is not None: + is_satisfied = VersionComparison.from_string(delimiter)( + version.parse(local_package_version), version.parse(version_number) + ) + else: + is_satisfied = True + + if not is_satisfied: + failed.append(f"{requirement} (installed: {local_package_version})") + + except OSError: # no requirements.txt + pass + + if failed: + raise ImportError( + f"Missing requirements in your local environment for `{path_or_repo_id}`:\n" + "\n".join(failed) + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/feature_extraction_sequence_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/feature_extraction_sequence_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e0be17bd7d28cf61c70c177af4d025d0321617d7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/feature_extraction_sequence_utils.py @@ -0,0 +1,371 @@ +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Sequence feature extraction class for common feature extractors to preprocess sequences. +""" + +from typing import Optional, Union + +import numpy as np + +from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin +from .utils import PaddingStrategy, TensorType, is_tf_tensor, is_torch_tensor, logging, to_numpy + + +logger = logging.get_logger(__name__) + + +class SequenceFeatureExtractor(FeatureExtractionMixin): + """ + This is a general feature extraction class for speech recognition. + + Args: + feature_size (`int`): + The feature dimension of the extracted features. + sampling_rate (`int`): + The sampling rate at which the audio files should be digitalized expressed in hertz (Hz). + padding_value (`float`): + The value that is used to fill the padding values / vectors. + """ + + def __init__(self, feature_size: int, sampling_rate: int, padding_value: float, **kwargs): + self.feature_size = feature_size + self.sampling_rate = sampling_rate + self.padding_value = padding_value + + self.padding_side = kwargs.pop("padding_side", "right") + self.return_attention_mask = kwargs.pop("return_attention_mask", True) + + super().__init__(**kwargs) + + def pad( + self, + processed_features: Union[ + BatchFeature, + list[BatchFeature], + dict[str, BatchFeature], + dict[str, list[BatchFeature]], + list[dict[str, BatchFeature]], + ], + padding: Union[bool, str, PaddingStrategy] = True, + max_length: Optional[int] = None, + truncation: bool = False, + pad_to_multiple_of: Optional[int] = None, + return_attention_mask: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + ) -> BatchFeature: + """ + Pad input values / input vectors or a batch of input values / input vectors up to predefined length or to the + max sequence length in the batch. + + Padding side (left/right) padding values are defined at the feature extractor level (with `self.padding_side`, + `self.padding_value`) + + + + If the `processed_features` passed are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the + result will use the same type unless you provide a different tensor type with `return_tensors`. In the case of + PyTorch tensors, you will lose the specific device of your tensors however. + + + + Args: + processed_features ([`BatchFeature`], list of [`BatchFeature`], `dict[str, list[float]]`, `dict[str, list[list[float]]` or `list[dict[str, list[float]]]`): + Processed inputs. Can represent one input ([`BatchFeature`] or `dict[str, list[float]]`) or a batch of + input values / vectors (list of [`BatchFeature`], *dict[str, list[list[float]]]* or *list[dict[str, + list[float]]]*) so you can use this method during preprocessing as well as in a PyTorch Dataloader + collate function. + + Instead of `list[float]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors), + see the note above for the return type. + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding + index) among: + + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + max_length (`int`, *optional*): + Maximum length of the returned list and optionally padding length (see above). + truncation (`bool`): + Activates truncation to cut input sequences longer than `max_length` to `max_length`. + pad_to_multiple_of (`int`, *optional*): + If set will pad the sequence to a multiple of the provided value. + + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability + `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. + return_attention_mask (`bool`, *optional*): + Whether to return the attention mask. If left to the default, will return the attention mask according + to the specific feature_extractor's default. + + [What are attention masks?](../glossary#attention-mask) + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors instead of list of python integers. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return Numpy `np.ndarray` objects. + """ + # If we have a list of dicts, let's convert it in a dict of lists + # We do this to allow using this method as a collate_fn function in PyTorch Dataloader + if isinstance(processed_features, (list, tuple)) and isinstance(processed_features[0], (dict, BatchFeature)): + processed_features = { + key: [example[key] for example in processed_features] for key in processed_features[0] + } + + # The model's main input name, usually `input_values`, has be passed for padding + if self.model_input_names[0] not in processed_features: + raise ValueError( + "You should supply an instance of `transformers.BatchFeature` or list of `transformers.BatchFeature`" + f" to this method that includes {self.model_input_names[0]}, but you provided" + f" {list(processed_features.keys())}" + ) + + required_input = processed_features[self.model_input_names[0]] + return_attention_mask = ( + return_attention_mask if return_attention_mask is not None else self.return_attention_mask + ) + + if len(required_input) == 0: + if return_attention_mask: + processed_features["attention_mask"] = [] + return processed_features + + # If we have PyTorch/TF tensors or lists as inputs, we cast them as Numpy arrays + # and rebuild them afterwards if no return_tensors is specified + # Note that we lose the specific device the tensor may be on for PyTorch + + first_element = required_input[0] + if isinstance(first_element, (list, tuple)): + # first_element might be an empty list/tuple in some edge cases so we grab the first non empty element. + index = 0 + while len(required_input[index]) == 0: + index += 1 + if index < len(required_input): + first_element = required_input[index][0] + + if return_tensors is None: + if is_tf_tensor(first_element): + return_tensors = "tf" + elif is_torch_tensor(first_element): + return_tensors = "pt" + elif isinstance(first_element, (int, float, list, tuple, np.ndarray)): + return_tensors = "np" + else: + raise ValueError( + f"type of {first_element} unknown: {type(first_element)}. " + "Should be one of a python, numpy, pytorch or tensorflow object." + ) + + for key, value in processed_features.items(): + if isinstance(value[0], (int, float)): + processed_features[key] = to_numpy(value) + else: + processed_features[key] = [to_numpy(v) for v in value] + + # Convert padding_strategy in PaddingStrategy + padding_strategy = self._get_padding_strategies(padding=padding, max_length=max_length) + + required_input = processed_features[self.model_input_names[0]] + + batch_size = len(required_input) + if not all(len(v) == batch_size for v in processed_features.values()): + raise ValueError("Some items in the output dictionary have a different batch size than others.") + + truncated_inputs = [] + for i in range(batch_size): + inputs = {k: v[i] for k, v in processed_features.items()} + # truncation + inputs_slice = self._truncate( + inputs, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + truncation=truncation, + ) + truncated_inputs.append(inputs_slice) + + if padding_strategy == PaddingStrategy.LONGEST: + # make sure that `max_length` cannot be longer than the longest truncated length + max_length = max(len(input_slice[self.model_input_names[0]]) for input_slice in truncated_inputs) + padding_strategy = PaddingStrategy.MAX_LENGTH + + batch_outputs = {} + for i in range(batch_size): + # padding + outputs = self._pad( + truncated_inputs[i], + max_length=max_length, + padding_strategy=padding_strategy, + pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=return_attention_mask, + ) + + for key, value in outputs.items(): + if key not in batch_outputs: + batch_outputs[key] = [] + if value.dtype is np.dtype(np.float64): + value = value.astype(np.float32) + batch_outputs[key].append(value) + + return BatchFeature(batch_outputs, tensor_type=return_tensors) + + def _pad( + self, + processed_features: Union[dict[str, np.ndarray], BatchFeature], + max_length: Optional[int] = None, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + pad_to_multiple_of: Optional[int] = None, + return_attention_mask: Optional[bool] = None, + ) -> dict: + """ + Pad inputs (on left/right and up to predefined length or max length in the batch) + + Args: + processed_features (`Union[dict[str, np.ndarray], BatchFeature]`): + Dictionary of input values (`np.ndarray[float]`) / input vectors (`list[np.ndarray[float]]`) or batch + of inputs values (`list[np.ndarray[int]]`) / input vectors (`list[np.ndarray[int]]`) + max_length (`int`, *optional*): + Maximum length of the returned list and optionally padding length (see below) + padding_strategy (`PaddingStrategy`, *optional*, default to `PaddingStrategy.DO_NOT_PAD`): + PaddingStrategy to use for padding. + + - PaddingStrategy.LONGEST Pad to the longest sequence in the batch + - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) + - PaddingStrategy.DO_NOT_PAD: Do not pad + The feature_extractor padding sides are defined in self.padding_side: + + - 'left': pads on the left of the sequences + - 'right': pads on the right of the sequences + pad_to_multiple_of (`int`, *optional*): + Integer if set will pad the sequence to a multiple of the provided value. This is especially useful to + enable the use of Tensor Core on NVIDIA hardware with compute capability `>= 7.5` (Volta), or on TPUs + which benefit from having sequence lengths be a multiple of 128. + return_attention_mask (`bool`, *optional*): + Set to False to avoid returning attention mask (default: set to model specifics) + """ + required_input = processed_features[self.model_input_names[0]] + + if padding_strategy == PaddingStrategy.LONGEST: + max_length = len(required_input) + + if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): + max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + + needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) < max_length + + if return_attention_mask and "attention_mask" not in processed_features: + processed_features["attention_mask"] = np.ones(len(required_input), dtype=np.int32) + + if needs_to_be_padded: + difference = max_length - len(required_input) + if self.padding_side == "right": + if return_attention_mask: + processed_features["attention_mask"] = np.pad( + processed_features["attention_mask"], (0, difference) + ) + padding_shape = ((0, difference), (0, 0)) if self.feature_size > 1 else (0, difference) + processed_features[self.model_input_names[0]] = np.pad( + required_input, padding_shape, "constant", constant_values=self.padding_value + ) + elif self.padding_side == "left": + if return_attention_mask: + processed_features["attention_mask"] = np.pad( + processed_features["attention_mask"], (difference, 0) + ) + padding_shape = ((difference, 0), (0, 0)) if self.feature_size > 1 else (difference, 0) + processed_features[self.model_input_names[0]] = np.pad( + required_input, padding_shape, "constant", constant_values=self.padding_value + ) + else: + raise ValueError("Invalid padding strategy:" + str(self.padding_side)) + + return processed_features + + def _truncate( + self, + processed_features: Union[dict[str, np.ndarray], BatchFeature], + max_length: Optional[int] = None, + pad_to_multiple_of: Optional[int] = None, + truncation: Optional[bool] = None, + ): + """ + Truncate inputs to predefined length or max length in the batch + + Args: + processed_features(`Union[dict[str, np.ndarray], BatchFeature]`): + Dictionary of input values (`np.ndarray[float]`) / input vectors (`list[np.ndarray[float]]`) or batch + of inputs values (`list[np.ndarray[int]]`) / input vectors (`list[np.ndarray[int]]`) + max_length (`int`, *optional*): + maximum length of the returned list and optionally padding length (see below) + pad_to_multiple_of (`int`, *optional*) : + Integer if set will pad the sequence to a multiple of the provided value. This is especially useful to + enable the use of Tensor Core on NVIDIA hardware with compute capability `>= 7.5` (Volta), or on TPUs + which benefit from having sequence lengths be a multiple of 128. + truncation (`bool`, *optional*): + Activates truncation to cut input sequences longer than `max_length` to `max_length`. + """ + if not truncation: + return processed_features + elif truncation and max_length is None: + raise ValueError("When setting ``truncation=True``, make sure that ``max_length`` is defined.") + + required_input = processed_features[self.model_input_names[0]] + + # find `max_length` that fits `pad_to_multiple_of` + if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): + max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + + needs_to_be_truncated = len(required_input) > max_length + + if needs_to_be_truncated: + processed_features[self.model_input_names[0]] = processed_features[self.model_input_names[0]][:max_length] + if "attention_mask" in processed_features: + processed_features["attention_mask"] = processed_features["attention_mask"][:max_length] + + return processed_features + + def _get_padding_strategies(self, padding=False, max_length=None): + """ + Find the correct padding strategy + """ + + # Get padding strategy + if padding is not False: + if padding is True: + padding_strategy = PaddingStrategy.LONGEST # Default to pad to the longest sequence in the batch + elif not isinstance(padding, PaddingStrategy): + padding_strategy = PaddingStrategy(padding) + elif isinstance(padding, PaddingStrategy): + padding_strategy = padding + else: + padding_strategy = PaddingStrategy.DO_NOT_PAD + + # Set max length if needed + if max_length is None: + if padding_strategy == PaddingStrategy.MAX_LENGTH: + raise ValueError( + f"When setting ``padding={PaddingStrategy.MAX_LENGTH}``, make sure that max_length is defined" + ) + + # Test if we have a padding value + if padding_strategy != PaddingStrategy.DO_NOT_PAD and (self.padding_value is None): + raise ValueError( + "Asking to pad but the feature_extractor does not have a padding value. Please select a value to use" + " as `padding_value`. For example: `feature_extractor.padding_value = 0.0`." + ) + + return padding_strategy diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/feature_extraction_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/feature_extraction_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e007e72d47612e377b3ce76c556d14e73a699095 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/feature_extraction_utils.py @@ -0,0 +1,697 @@ +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Feature extraction saving/loading class for common feature extractors. +""" + +import copy +import json +import os +import warnings +from collections import UserDict +from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union + +import numpy as np + +from .dynamic_module_utils import custom_object_save +from .utils import ( + FEATURE_EXTRACTOR_NAME, + PROCESSOR_NAME, + PushToHubMixin, + TensorType, + copy_func, + download_url, + is_flax_available, + is_jax_tensor, + is_numpy_array, + is_offline_mode, + is_remote_url, + is_tf_available, + is_torch_available, + is_torch_device, + is_torch_dtype, + logging, + requires_backends, +) +from .utils.hub import cached_file + + +if TYPE_CHECKING: + from .feature_extraction_sequence_utils import SequenceFeatureExtractor + + +logger = logging.get_logger(__name__) + +PreTrainedFeatureExtractor = Union["SequenceFeatureExtractor"] + +# type hinting: specifying the type of feature extractor class that inherits from FeatureExtractionMixin +SpecificFeatureExtractorType = TypeVar("SpecificFeatureExtractorType", bound="FeatureExtractionMixin") + + +class BatchFeature(UserDict): + r""" + Holds the output of the [`~SequenceFeatureExtractor.pad`] and feature extractor specific `__call__` methods. + + This class is derived from a python dictionary and can be used as a dictionary. + + Args: + data (`dict`, *optional*): + Dictionary of lists/arrays/tensors returned by the __call__/pad methods ('input_values', 'attention_mask', + etc.). + tensor_type (`Union[None, str, TensorType]`, *optional*): + You can give a tensor_type here to convert the lists of integers in PyTorch/TensorFlow/Numpy Tensors at + initialization. + """ + + def __init__(self, data: Optional[dict[str, Any]] = None, tensor_type: Union[None, str, TensorType] = None): + super().__init__(data) + self.convert_to_tensors(tensor_type=tensor_type) + + def __getitem__(self, item: str) -> Any: + """ + If the key is a string, returns the value of the dict associated to `key` ('input_values', 'attention_mask', + etc.). + """ + if isinstance(item, str): + return self.data[item] + else: + raise KeyError("Indexing with integers is not available when using Python based feature extractors") + + def __getattr__(self, item: str): + try: + return self.data[item] + except KeyError: + raise AttributeError + + def __getstate__(self): + return {"data": self.data} + + def __setstate__(self, state): + if "data" in state: + self.data = state["data"] + + def _get_is_as_tensor_fns(self, tensor_type: Optional[Union[str, TensorType]] = None): + if tensor_type is None: + return None, None + + # Convert to TensorType + if not isinstance(tensor_type, TensorType): + tensor_type = TensorType(tensor_type) + + # Get a function reference for the correct framework + if tensor_type == TensorType.TENSORFLOW: + logger.warning_once( + "TensorFlow and JAX classes are deprecated and will be removed in Transformers v5. We " + "recommend migrating to PyTorch classes or pinning your version of Transformers." + ) + if not is_tf_available(): + raise ImportError( + "Unable to convert output to TensorFlow tensors format, TensorFlow is not installed." + ) + import tensorflow as tf + + as_tensor = tf.constant + is_tensor = tf.is_tensor + elif tensor_type == TensorType.PYTORCH: + if not is_torch_available(): + raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.") + import torch + + def as_tensor(value): + if isinstance(value, (list, tuple)) and len(value) > 0: + if isinstance(value[0], np.ndarray): + value = np.array(value) + elif ( + isinstance(value[0], (list, tuple)) + and len(value[0]) > 0 + and isinstance(value[0][0], np.ndarray) + ): + value = np.array(value) + if isinstance(value, np.ndarray): + return torch.from_numpy(value) + else: + return torch.tensor(value) + + is_tensor = torch.is_tensor + elif tensor_type == TensorType.JAX: + logger.warning_once( + "TensorFlow and JAX classes are deprecated and will be removed in Transformers v5. We " + "recommend migrating to PyTorch classes or pinning your version of Transformers." + ) + if not is_flax_available(): + raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.") + import jax.numpy as jnp # noqa: F811 + + as_tensor = jnp.array + is_tensor = is_jax_tensor + else: + + def as_tensor(value, dtype=None): + if isinstance(value, (list, tuple)) and isinstance(value[0], (list, tuple, np.ndarray)): + value_lens = [len(val) for val in value] + if len(set(value_lens)) > 1 and dtype is None: + # we have a ragged list so handle explicitly + value = as_tensor([np.asarray(val) for val in value], dtype=object) + return np.asarray(value, dtype=dtype) + + is_tensor = is_numpy_array + return is_tensor, as_tensor + + def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None): + """ + Convert the inner content to tensors. + + Args: + tensor_type (`str` or [`~utils.TensorType`], *optional*): + The type of tensors to use. If `str`, should be one of the values of the enum [`~utils.TensorType`]. If + `None`, no modification is done. + """ + if tensor_type is None: + return self + + is_tensor, as_tensor = self._get_is_as_tensor_fns(tensor_type) + + # Do the tensor conversion in batch + for key, value in self.items(): + try: + if not is_tensor(value): + tensor = as_tensor(value) + + self[key] = tensor + except: # noqa E722 + if key == "overflowing_values": + raise ValueError("Unable to create tensor returning overflowing values of different lengths. ") + raise ValueError( + "Unable to create tensor, you should probably activate padding " + "with 'padding=True' to have batched tensors with the same length." + ) + + return self + + def to(self, *args, **kwargs) -> "BatchFeature": + """ + Send all values to device by calling `v.to(*args, **kwargs)` (PyTorch only). This should support casting in + different `dtypes` and sending the `BatchFeature` to a different `device`. + + Args: + args (`Tuple`): + Will be passed to the `to(...)` function of the tensors. + kwargs (`Dict`, *optional*): + Will be passed to the `to(...)` function of the tensors. + To enable asynchronous data transfer, set the `non_blocking` flag in `kwargs` (defaults to `False`). + + Returns: + [`BatchFeature`]: The same instance after modification. + """ + requires_backends(self, ["torch"]) + import torch + + device = kwargs.get("device") + non_blocking = kwargs.get("non_blocking", False) + # Check if the args are a device or a dtype + if device is None and len(args) > 0: + # device should be always the first argument + arg = args[0] + if is_torch_dtype(arg): + # The first argument is a dtype + pass + elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int): + device = arg + else: + # it's something else + raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.") + + # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor` + def maybe_to(v): + # check if v is a floating point + if isinstance(v, torch.Tensor) and torch.is_floating_point(v): + # cast and send to device + return v.to(*args, **kwargs) + elif isinstance(v, torch.Tensor) and device is not None: + return v.to(device=device, non_blocking=non_blocking) + else: + return v + + self.data = {k: maybe_to(v) for k, v in self.items()} + return self + + +class FeatureExtractionMixin(PushToHubMixin): + """ + This is a feature extraction mixin used to provide saving/loading functionality for sequential and image feature + extractors. + """ + + _auto_class = None + + def __init__(self, **kwargs): + """Set elements of `kwargs` as attributes.""" + # Pop "processor_class" as it should be saved as private attribute + self._processor_class = kwargs.pop("processor_class", None) + # Additional attributes without default values + for key, value in kwargs.items(): + try: + setattr(self, key, value) + except AttributeError as err: + logger.error(f"Can't set {key} with value {value} for {self}") + raise err + + def _set_processor_class(self, processor_class: str): + """Sets processor class as an attribute.""" + self._processor_class = processor_class + + @classmethod + def from_pretrained( + cls: type[SpecificFeatureExtractorType], + pretrained_model_name_or_path: Union[str, os.PathLike], + cache_dir: Optional[Union[str, os.PathLike]] = None, + force_download: bool = False, + local_files_only: bool = False, + token: Optional[Union[str, bool]] = None, + revision: str = "main", + **kwargs, + ) -> SpecificFeatureExtractorType: + r""" + Instantiate a type of [`~feature_extraction_utils.FeatureExtractionMixin`] from a feature extractor, *e.g.* a + derived class of [`SequenceFeatureExtractor`]. + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`): + This can be either: + + - a string, the *model id* of a pretrained feature_extractor hosted inside a model repo on + huggingface.co. + - a path to a *directory* containing a feature extractor file saved using the + [`~feature_extraction_utils.FeatureExtractionMixin.save_pretrained`] method, e.g., + `./my_model_directory/`. + - a path or url to a saved feature extractor JSON *file*, e.g., + `./my_model_directory/preprocessor_config.json`. + cache_dir (`str` or `os.PathLike`, *optional*): + Path to a directory in which a downloaded pretrained model feature extractor should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force to (re-)download the feature extractor files and override the cached versions + if they exist. + resume_download: + Deprecated and ignored. All downloads are now resumed by default when possible. + Will be removed in v5 of Transformers. + proxies (`dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. + token (`str` or `bool`, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use + the token generated when running `hf auth login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + + + + + To test a pull request you made on the Hub, you can pass `revision="refs/pr/"`. + + + + return_unused_kwargs (`bool`, *optional*, defaults to `False`): + If `False`, then this function returns just the final feature extractor object. If `True`, then this + functions returns a `Tuple(feature_extractor, unused_kwargs)` where *unused_kwargs* is a dictionary + consisting of the key/value pairs whose keys are not feature extractor attributes: i.e., the part of + `kwargs` which has not been used to update `feature_extractor` and is otherwise ignored. + kwargs (`dict[str, Any]`, *optional*): + The values in kwargs of any keys which are feature extractor attributes will be used to override the + loaded values. Behavior concerning key/value pairs whose keys are *not* feature extractor attributes is + controlled by the `return_unused_kwargs` keyword parameter. + + Returns: + A feature extractor of type [`~feature_extraction_utils.FeatureExtractionMixin`]. + + Examples: + + ```python + # We can't instantiate directly the base class *FeatureExtractionMixin* nor *SequenceFeatureExtractor* so let's show the examples on a + # derived class: *Wav2Vec2FeatureExtractor* + feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained( + "facebook/wav2vec2-base-960h" + ) # Download feature_extraction_config from huggingface.co and cache. + feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained( + "./test/saved_model/" + ) # E.g. feature_extractor (or model) was saved using *save_pretrained('./test/saved_model/')* + feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("./test/saved_model/preprocessor_config.json") + feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained( + "facebook/wav2vec2-base-960h", return_attention_mask=False, foo=False + ) + assert feature_extractor.return_attention_mask is False + feature_extractor, unused_kwargs = Wav2Vec2FeatureExtractor.from_pretrained( + "facebook/wav2vec2-base-960h", return_attention_mask=False, foo=False, return_unused_kwargs=True + ) + assert feature_extractor.return_attention_mask is False + assert unused_kwargs == {"foo": False} + ```""" + kwargs["cache_dir"] = cache_dir + kwargs["force_download"] = force_download + kwargs["local_files_only"] = local_files_only + kwargs["revision"] = revision + + use_auth_token = kwargs.pop("use_auth_token", None) + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if token is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + token = use_auth_token + + if token is not None: + kwargs["token"] = token + + feature_extractor_dict, kwargs = cls.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs) + + return cls.from_dict(feature_extractor_dict, **kwargs) + + def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): + """ + Save a feature_extractor object to the directory `save_directory`, so that it can be re-loaded using the + [`~feature_extraction_utils.FeatureExtractionMixin.from_pretrained`] class method. + + Args: + save_directory (`str` or `os.PathLike`): + Directory where the feature extractor JSON file will be saved (will be created if it does not exist). + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the + repository you want to push to with `repo_id` (will default to the name of `save_directory` in your + namespace). + kwargs (`dict[str, Any]`, *optional*): + Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. + """ + use_auth_token = kwargs.pop("use_auth_token", None) + + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if kwargs.get("token") is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + kwargs["token"] = use_auth_token + + if os.path.isfile(save_directory): + raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") + + os.makedirs(save_directory, exist_ok=True) + + if push_to_hub: + commit_message = kwargs.pop("commit_message", None) + repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) + repo_id = self._create_repo(repo_id, **kwargs) + files_timestamps = self._get_files_timestamps(save_directory) + + # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be + # loaded from the Hub. + if self._auto_class is not None: + custom_object_save(self, save_directory, config=self) + + # If we save using the predefined names, we can load using `from_pretrained` + output_feature_extractor_file = os.path.join(save_directory, FEATURE_EXTRACTOR_NAME) + + self.to_json_file(output_feature_extractor_file) + logger.info(f"Feature extractor saved in {output_feature_extractor_file}") + + if push_to_hub: + self._upload_modified_files( + save_directory, + repo_id, + files_timestamps, + commit_message=commit_message, + token=kwargs.get("token"), + ) + + return [output_feature_extractor_file] + + @classmethod + def get_feature_extractor_dict( + cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs + ) -> tuple[dict[str, Any], dict[str, Any]]: + """ + From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a + feature extractor of type [`~feature_extraction_utils.FeatureExtractionMixin`] using `from_dict`. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`): + The identifier of the pre-trained checkpoint from which we want the dictionary of parameters. + + Returns: + `tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the feature extractor object. + """ + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", None) + proxies = kwargs.pop("proxies", None) + subfolder = kwargs.pop("subfolder", None) + token = kwargs.pop("token", None) + use_auth_token = kwargs.pop("use_auth_token", None) + local_files_only = kwargs.pop("local_files_only", False) + revision = kwargs.pop("revision", None) + + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if token is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + token = use_auth_token + + from_pipeline = kwargs.pop("_from_pipeline", None) + from_auto_class = kwargs.pop("_from_auto", False) + + user_agent = {"file_type": "feature extractor", "from_auto_class": from_auto_class} + if from_pipeline is not None: + user_agent["using_pipeline"] = from_pipeline + + if is_offline_mode() and not local_files_only: + logger.info("Offline mode: forcing local_files_only=True") + local_files_only = True + + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + is_local = os.path.isdir(pretrained_model_name_or_path) + if os.path.isdir(pretrained_model_name_or_path): + feature_extractor_file = os.path.join(pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME) + if os.path.isfile(pretrained_model_name_or_path): + resolved_feature_extractor_file = pretrained_model_name_or_path + is_local = True + elif is_remote_url(pretrained_model_name_or_path): + feature_extractor_file = pretrained_model_name_or_path + resolved_feature_extractor_file = download_url(pretrained_model_name_or_path) + else: + feature_extractor_file = FEATURE_EXTRACTOR_NAME + try: + # Load from local folder or from cache or download from model Hub and cache + resolved_feature_extractor_files = [ + resolved_file + for filename in [feature_extractor_file, PROCESSOR_NAME] + if ( + resolved_file := cached_file( + pretrained_model_name_or_path, + filename=filename, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + subfolder=subfolder, + token=token, + user_agent=user_agent, + revision=revision, + _raise_exceptions_for_missing_entries=False, + ) + ) + is not None + ] + resolved_feature_extractor_file = resolved_feature_extractor_files[0] + except OSError: + # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to + # the original exception. + raise + except Exception: + # For any other exception, we throw a generic error. + raise OSError( + f"Can't load feature extractor for '{pretrained_model_name_or_path}'. If you were trying to load" + " it from 'https://huggingface.co/models', make sure you don't have a local directory with the" + f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a" + f" directory containing a {FEATURE_EXTRACTOR_NAME} file" + ) + + try: + # Load feature_extractor dict + with open(resolved_feature_extractor_file, encoding="utf-8") as reader: + text = reader.read() + feature_extractor_dict = json.loads(text) + feature_extractor_dict = feature_extractor_dict.get("feature_extractor", feature_extractor_dict) + + except json.JSONDecodeError: + raise OSError( + f"It looks like the config file at '{resolved_feature_extractor_file}' is not a valid JSON file." + ) + + if is_local: + logger.info(f"loading configuration file {resolved_feature_extractor_file}") + else: + logger.info( + f"loading configuration file {feature_extractor_file} from cache at {resolved_feature_extractor_file}" + ) + + return feature_extractor_dict, kwargs + + @classmethod + def from_dict( + cls, feature_extractor_dict: dict[str, Any], **kwargs + ) -> Union["FeatureExtractionMixin", tuple["FeatureExtractionMixin", dict[str, Any]]]: + """ + Instantiates a type of [`~feature_extraction_utils.FeatureExtractionMixin`] from a Python dictionary of + parameters. + + Args: + feature_extractor_dict (`dict[str, Any]`): + Dictionary that will be used to instantiate the feature extractor object. Such a dictionary can be + retrieved from a pretrained checkpoint by leveraging the + [`~feature_extraction_utils.FeatureExtractionMixin.to_dict`] method. + kwargs (`dict[str, Any]`): + Additional parameters from which to initialize the feature extractor object. + + Returns: + [`~feature_extraction_utils.FeatureExtractionMixin`]: The feature extractor object instantiated from those + parameters. + """ + return_unused_kwargs = kwargs.pop("return_unused_kwargs", False) + + # Update feature_extractor with kwargs if needed + to_remove = [] + for key, value in kwargs.items(): + if key in feature_extractor_dict: + feature_extractor_dict[key] = value + to_remove.append(key) + for key in to_remove: + kwargs.pop(key, None) + + feature_extractor = cls(**feature_extractor_dict) + + logger.info(f"Feature extractor {feature_extractor}") + if return_unused_kwargs: + return feature_extractor, kwargs + else: + return feature_extractor + + def to_dict(self) -> dict[str, Any]: + """ + Serializes this instance to a Python dictionary. Returns: + `dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. + """ + output = copy.deepcopy(self.__dict__) + output["feature_extractor_type"] = self.__class__.__name__ + if "mel_filters" in output: + del output["mel_filters"] + if "window" in output: + del output["window"] + return output + + @classmethod + def from_json_file(cls, json_file: Union[str, os.PathLike]) -> "FeatureExtractionMixin": + """ + Instantiates a feature extractor of type [`~feature_extraction_utils.FeatureExtractionMixin`] from the path to + a JSON file of parameters. + + Args: + json_file (`str` or `os.PathLike`): + Path to the JSON file containing the parameters. + + Returns: + A feature extractor of type [`~feature_extraction_utils.FeatureExtractionMixin`]: The feature_extractor + object instantiated from that JSON file. + """ + with open(json_file, encoding="utf-8") as reader: + text = reader.read() + feature_extractor_dict = json.loads(text) + return cls(**feature_extractor_dict) + + def to_json_string(self) -> str: + """ + Serializes this instance to a JSON string. + + Returns: + `str`: String containing all the attributes that make up this feature_extractor instance in JSON format. + """ + dictionary = self.to_dict() + + for key, value in dictionary.items(): + if isinstance(value, np.ndarray): + dictionary[key] = value.tolist() + + # make sure private name "_processor_class" is correctly + # saved as "processor_class" + _processor_class = dictionary.pop("_processor_class", None) + if _processor_class is not None: + dictionary["processor_class"] = _processor_class + + return json.dumps(dictionary, indent=2, sort_keys=True) + "\n" + + def to_json_file(self, json_file_path: Union[str, os.PathLike]): + """ + Save this instance to a JSON file. + + Args: + json_file_path (`str` or `os.PathLike`): + Path to the JSON file in which this feature_extractor instance's parameters will be saved. + """ + with open(json_file_path, "w", encoding="utf-8") as writer: + writer.write(self.to_json_string()) + + def __repr__(self): + return f"{self.__class__.__name__} {self.to_json_string()}" + + @classmethod + def register_for_auto_class(cls, auto_class="AutoFeatureExtractor"): + """ + Register this class with a given auto class. This should only be used for custom feature extractors as the ones + in the library are already mapped with `AutoFeatureExtractor`. + + + + Args: + auto_class (`str` or `type`, *optional*, defaults to `"AutoFeatureExtractor"`): + The auto class to register this new feature extractor with. + """ + if not isinstance(auto_class, str): + auto_class = auto_class.__name__ + + import transformers.models.auto as auto_module + + if not hasattr(auto_module, auto_class): + raise ValueError(f"{auto_class} is not a valid auto class.") + + cls._auto_class = auto_class + + +FeatureExtractionMixin.push_to_hub = copy_func(FeatureExtractionMixin.push_to_hub) +if FeatureExtractionMixin.push_to_hub.__doc__ is not None: + FeatureExtractionMixin.push_to_hub.__doc__ = FeatureExtractionMixin.push_to_hub.__doc__.format( + object="feature extractor", object_class="AutoFeatureExtractor", object_files="feature extractor file" + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/file_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/file_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fc6f722262d99e0123423e66d2a5db3edd118ff5 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/file_utils.py @@ -0,0 +1,130 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +File utilities: utilities related to download and cache models + +This module should not be update anymore and is only left for backward compatibility. +""" + +from huggingface_hub import get_full_repo_name # for backward compatibility +from huggingface_hub.constants import HF_HUB_DISABLE_TELEMETRY as DISABLE_TELEMETRY # for backward compatibility + +from . import __version__ + +# Backward compatibility imports, to make sure all those objects can be found in file_utils +from .utils import ( + CLOUDFRONT_DISTRIB_PREFIX, + CONFIG_NAME, + DUMMY_INPUTS, + DUMMY_MASK, + ENV_VARS_TRUE_AND_AUTO_VALUES, + ENV_VARS_TRUE_VALUES, + FEATURE_EXTRACTOR_NAME, + FLAX_WEIGHTS_NAME, + HF_MODULES_CACHE, + HUGGINGFACE_CO_PREFIX, + HUGGINGFACE_CO_RESOLVE_ENDPOINT, + MODEL_CARD_NAME, + MULTIPLE_CHOICE_DUMMY_INPUTS, + PYTORCH_PRETRAINED_BERT_CACHE, + PYTORCH_TRANSFORMERS_CACHE, + S3_BUCKET_PREFIX, + SENTENCEPIECE_UNDERLINE, + SPIECE_UNDERLINE, + TF2_WEIGHTS_NAME, + TF_WEIGHTS_NAME, + TORCH_FX_REQUIRED_VERSION, + TRANSFORMERS_CACHE, + TRANSFORMERS_DYNAMIC_MODULE_NAME, + USE_JAX, + USE_TF, + USE_TORCH, + WEIGHTS_INDEX_NAME, + WEIGHTS_NAME, + ContextManagers, + DummyObject, + EntryNotFoundError, + ExplicitEnum, + ModelOutput, + PaddingStrategy, + PushToHubMixin, + RepositoryNotFoundError, + RevisionNotFoundError, + TensorType, + _LazyModule, + add_code_sample_docstrings, + add_end_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + copy_func, + default_cache_path, + define_sagemaker_information, + get_torch_version, + has_file, + http_user_agent, + is_apex_available, + is_bs4_available, + is_coloredlogs_available, + is_datasets_available, + is_detectron2_available, + is_faiss_available, + is_flax_available, + is_ftfy_available, + is_g2p_en_available, + is_in_notebook, + is_ipex_available, + is_librosa_available, + is_offline_mode, + is_onnx_available, + is_pandas_available, + is_phonemizer_available, + is_protobuf_available, + is_psutil_available, + is_py3nvml_available, + is_pyctcdecode_available, + is_pytesseract_available, + is_pytorch_quantization_available, + is_rjieba_available, + is_sagemaker_dp_enabled, + is_sagemaker_mp_enabled, + is_scipy_available, + is_sentencepiece_available, + is_seqio_available, + is_sklearn_available, + is_soundfile_available, + is_spacy_available, + is_speech_available, + is_tensor, + is_tensorflow_probability_available, + is_tf2onnx_available, + is_tf_available, + is_timm_available, + is_tokenizers_available, + is_torch_available, + is_torch_bf16_available, + is_torch_cuda_available, + is_torch_fx_available, + is_torch_fx_proxy, + is_torch_mps_available, + is_torch_tf32_available, + is_torch_xla_available, + is_torchaudio_available, + is_training_run_on_sagemaker, + is_vision_available, + replace_return_docstrings, + requires_backends, + to_numpy, + to_py_obj, + torch_only_method, +) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/hyperparameter_search.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/hyperparameter_search.py new file mode 100644 index 0000000000000000000000000000000000000000..e8558ceed32f4640c727a192d3ab00ed9104986f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/hyperparameter_search.py @@ -0,0 +1,141 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +from .integrations import ( + is_optuna_available, + is_ray_tune_available, + is_sigopt_available, + is_wandb_available, + run_hp_search_optuna, + run_hp_search_ray, + run_hp_search_sigopt, + run_hp_search_wandb, +) +from .trainer_utils import ( + HPSearchBackend, + default_hp_space_optuna, + default_hp_space_ray, + default_hp_space_sigopt, + default_hp_space_wandb, +) +from .utils import logging + + +logger = logging.get_logger(__name__) + + +class HyperParamSearchBackendBase: + name: str + pip_package: Optional[str] = None + + @staticmethod + def is_available(): + raise NotImplementedError + + def run(self, trainer, n_trials: int, direction: str, **kwargs): + raise NotImplementedError + + def default_hp_space(self, trial): + raise NotImplementedError + + def ensure_available(self): + if not self.is_available(): + raise RuntimeError( + f"You picked the {self.name} backend, but it is not installed. Run {self.pip_install()}." + ) + + @classmethod + def pip_install(cls): + return f"`pip install {cls.pip_package or cls.name}`" + + +class OptunaBackend(HyperParamSearchBackendBase): + name = "optuna" + + @staticmethod + def is_available(): + return is_optuna_available() + + def run(self, trainer, n_trials: int, direction: str, **kwargs): + return run_hp_search_optuna(trainer, n_trials, direction, **kwargs) + + def default_hp_space(self, trial): + return default_hp_space_optuna(trial) + + +class RayTuneBackend(HyperParamSearchBackendBase): + name = "ray" + pip_package = "'ray[tune]'" + + @staticmethod + def is_available(): + return is_ray_tune_available() + + def run(self, trainer, n_trials: int, direction: str, **kwargs): + return run_hp_search_ray(trainer, n_trials, direction, **kwargs) + + def default_hp_space(self, trial): + return default_hp_space_ray(trial) + + +class SigOptBackend(HyperParamSearchBackendBase): + name = "sigopt" + + @staticmethod + def is_available(): + return is_sigopt_available() + + def run(self, trainer, n_trials: int, direction: str, **kwargs): + return run_hp_search_sigopt(trainer, n_trials, direction, **kwargs) + + def default_hp_space(self, trial): + return default_hp_space_sigopt(trial) + + +class WandbBackend(HyperParamSearchBackendBase): + name = "wandb" + + @staticmethod + def is_available(): + return is_wandb_available() + + def run(self, trainer, n_trials: int, direction: str, **kwargs): + return run_hp_search_wandb(trainer, n_trials, direction, **kwargs) + + def default_hp_space(self, trial): + return default_hp_space_wandb(trial) + + +ALL_HYPERPARAMETER_SEARCH_BACKENDS = { + HPSearchBackend(backend.name): backend for backend in [OptunaBackend, RayTuneBackend, SigOptBackend, WandbBackend] +} + + +def default_hp_search_backend() -> str: + available_backends = [backend for backend in ALL_HYPERPARAMETER_SEARCH_BACKENDS.values() if backend.is_available()] + if len(available_backends) > 0: + name = available_backends[0].name + if len(available_backends) > 1: + logger.info( + f"{len(available_backends)} hyperparameter search backends available. Using {name} as the default." + ) + return name + raise RuntimeError( + "No hyperparameter search backend available.\n" + + "\n".join( + f" - To install {backend.name} run {backend.pip_install()}" + for backend in ALL_HYPERPARAMETER_SEARCH_BACKENDS.values() + ) + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/image_processing_base.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/image_processing_base.py new file mode 100644 index 0000000000000000000000000000000000000000..dfe94ffd0df794588b91544ffe9764ee5717c327 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/image_processing_base.py @@ -0,0 +1,543 @@ +# Copyright 2020 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import copy +import json +import os +import warnings +from typing import Any, Optional, TypeVar, Union + +import numpy as np + +from .dynamic_module_utils import custom_object_save +from .feature_extraction_utils import BatchFeature as BaseBatchFeature +from .image_utils import is_valid_image, load_image +from .utils import ( + IMAGE_PROCESSOR_NAME, + PROCESSOR_NAME, + PushToHubMixin, + copy_func, + download_url, + is_offline_mode, + is_remote_url, + logging, +) +from .utils.hub import cached_file + + +ImageProcessorType = TypeVar("ImageProcessorType", bound="ImageProcessingMixin") + + +logger = logging.get_logger(__name__) + + +# TODO: Move BatchFeature to be imported by both image_processing_utils and image_processing_utils_fast +# We override the class string here, but logic is the same. +class BatchFeature(BaseBatchFeature): + r""" + Holds the output of the image processor specific `__call__` methods. + + This class is derived from a python dictionary and can be used as a dictionary. + + Args: + data (`dict`): + Dictionary of lists/arrays/tensors returned by the __call__ method ('pixel_values', etc.). + tensor_type (`Union[None, str, TensorType]`, *optional*): + You can give a tensor_type here to convert the lists of integers in PyTorch/TensorFlow/Numpy Tensors at + initialization. + """ + + +# TODO: (Amy) - factor out the common parts of this and the feature extractor +class ImageProcessingMixin(PushToHubMixin): + """ + This is an image processor mixin used to provide saving/loading functionality for sequential and image feature + extractors. + """ + + _auto_class = None + + def __init__(self, **kwargs): + """Set elements of `kwargs` as attributes.""" + # This key was saved while we still used `XXXFeatureExtractor` for image processing. Now we use + # `XXXImageProcessor`, this attribute and its value are misleading. + kwargs.pop("feature_extractor_type", None) + # Pop "processor_class" as it should be saved as private attribute + self._processor_class = kwargs.pop("processor_class", None) + # Additional attributes without default values + for key, value in kwargs.items(): + try: + setattr(self, key, value) + except AttributeError as err: + logger.error(f"Can't set {key} with value {value} for {self}") + raise err + + def _set_processor_class(self, processor_class: str): + """Sets processor class as an attribute.""" + self._processor_class = processor_class + + @classmethod + def from_pretrained( + cls: type[ImageProcessorType], + pretrained_model_name_or_path: Union[str, os.PathLike], + cache_dir: Optional[Union[str, os.PathLike]] = None, + force_download: bool = False, + local_files_only: bool = False, + token: Optional[Union[str, bool]] = None, + revision: str = "main", + **kwargs, + ) -> ImageProcessorType: + r""" + Instantiate a type of [`~image_processing_utils.ImageProcessingMixin`] from an image processor. + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`): + This can be either: + + - a string, the *model id* of a pretrained image_processor hosted inside a model repo on + huggingface.co. + - a path to a *directory* containing a image processor file saved using the + [`~image_processing_utils.ImageProcessingMixin.save_pretrained`] method, e.g., + `./my_model_directory/`. + - a path or url to a saved image processor JSON *file*, e.g., + `./my_model_directory/preprocessor_config.json`. + cache_dir (`str` or `os.PathLike`, *optional*): + Path to a directory in which a downloaded pretrained model image processor should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force to (re-)download the image processor files and override the cached versions if + they exist. + resume_download: + Deprecated and ignored. All downloads are now resumed by default when possible. + Will be removed in v5 of Transformers. + proxies (`dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. + token (`str` or `bool`, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use + the token generated when running `hf auth login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + + + + + To test a pull request you made on the Hub, you can pass `revision="refs/pr/"`. + + + + return_unused_kwargs (`bool`, *optional*, defaults to `False`): + If `False`, then this function returns just the final image processor object. If `True`, then this + functions returns a `Tuple(image_processor, unused_kwargs)` where *unused_kwargs* is a dictionary + consisting of the key/value pairs whose keys are not image processor attributes: i.e., the part of + `kwargs` which has not been used to update `image_processor` and is otherwise ignored. + subfolder (`str`, *optional*, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can + specify the folder name here. + kwargs (`dict[str, Any]`, *optional*): + The values in kwargs of any keys which are image processor attributes will be used to override the + loaded values. Behavior concerning key/value pairs whose keys are *not* image processor attributes is + controlled by the `return_unused_kwargs` keyword parameter. + + Returns: + A image processor of type [`~image_processing_utils.ImageProcessingMixin`]. + + Examples: + + ```python + # We can't instantiate directly the base class *ImageProcessingMixin* so let's show the examples on a + # derived class: *CLIPImageProcessor* + image_processor = CLIPImageProcessor.from_pretrained( + "openai/clip-vit-base-patch32" + ) # Download image_processing_config from huggingface.co and cache. + image_processor = CLIPImageProcessor.from_pretrained( + "./test/saved_model/" + ) # E.g. image processor (or model) was saved using *save_pretrained('./test/saved_model/')* + image_processor = CLIPImageProcessor.from_pretrained("./test/saved_model/preprocessor_config.json") + image_processor = CLIPImageProcessor.from_pretrained( + "openai/clip-vit-base-patch32", do_normalize=False, foo=False + ) + assert image_processor.do_normalize is False + image_processor, unused_kwargs = CLIPImageProcessor.from_pretrained( + "openai/clip-vit-base-patch32", do_normalize=False, foo=False, return_unused_kwargs=True + ) + assert image_processor.do_normalize is False + assert unused_kwargs == {"foo": False} + ```""" + kwargs["cache_dir"] = cache_dir + kwargs["force_download"] = force_download + kwargs["local_files_only"] = local_files_only + kwargs["revision"] = revision + + use_auth_token = kwargs.pop("use_auth_token", None) + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if token is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + token = use_auth_token + + if token is not None: + kwargs["token"] = token + + image_processor_dict, kwargs = cls.get_image_processor_dict(pretrained_model_name_or_path, **kwargs) + + return cls.from_dict(image_processor_dict, **kwargs) + + def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): + """ + Save an image processor object to the directory `save_directory`, so that it can be re-loaded using the + [`~image_processing_utils.ImageProcessingMixin.from_pretrained`] class method. + + Args: + save_directory (`str` or `os.PathLike`): + Directory where the image processor JSON file will be saved (will be created if it does not exist). + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the + repository you want to push to with `repo_id` (will default to the name of `save_directory` in your + namespace). + kwargs (`dict[str, Any]`, *optional*): + Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. + """ + use_auth_token = kwargs.pop("use_auth_token", None) + + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if kwargs.get("token") is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + kwargs["token"] = use_auth_token + + if os.path.isfile(save_directory): + raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") + + os.makedirs(save_directory, exist_ok=True) + + if push_to_hub: + commit_message = kwargs.pop("commit_message", None) + repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) + repo_id = self._create_repo(repo_id, **kwargs) + files_timestamps = self._get_files_timestamps(save_directory) + + # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be + # loaded from the Hub. + if self._auto_class is not None: + custom_object_save(self, save_directory, config=self) + + # If we save using the predefined names, we can load using `from_pretrained` + output_image_processor_file = os.path.join(save_directory, IMAGE_PROCESSOR_NAME) + + self.to_json_file(output_image_processor_file) + logger.info(f"Image processor saved in {output_image_processor_file}") + + if push_to_hub: + self._upload_modified_files( + save_directory, + repo_id, + files_timestamps, + commit_message=commit_message, + token=kwargs.get("token"), + ) + + return [output_image_processor_file] + + @classmethod + def get_image_processor_dict( + cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs + ) -> tuple[dict[str, Any], dict[str, Any]]: + """ + From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a + image processor of type [`~image_processor_utils.ImageProcessingMixin`] using `from_dict`. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`): + The identifier of the pre-trained checkpoint from which we want the dictionary of parameters. + subfolder (`str`, *optional*, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can + specify the folder name here. + image_processor_filename (`str`, *optional*, defaults to `"config.json"`): + The name of the file in the model directory to use for the image processor config. + + Returns: + `tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the image processor object. + """ + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", None) + proxies = kwargs.pop("proxies", None) + token = kwargs.pop("token", None) + use_auth_token = kwargs.pop("use_auth_token", None) + local_files_only = kwargs.pop("local_files_only", False) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", "") + image_processor_filename = kwargs.pop("image_processor_filename", IMAGE_PROCESSOR_NAME) + + from_pipeline = kwargs.pop("_from_pipeline", None) + from_auto_class = kwargs.pop("_from_auto", False) + + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if token is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + token = use_auth_token + + user_agent = {"file_type": "image processor", "from_auto_class": from_auto_class} + if from_pipeline is not None: + user_agent["using_pipeline"] = from_pipeline + + if is_offline_mode() and not local_files_only: + logger.info("Offline mode: forcing local_files_only=True") + local_files_only = True + + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + is_local = os.path.isdir(pretrained_model_name_or_path) + if os.path.isdir(pretrained_model_name_or_path): + image_processor_file = os.path.join(pretrained_model_name_or_path, image_processor_filename) + if os.path.isfile(pretrained_model_name_or_path): + resolved_image_processor_file = pretrained_model_name_or_path + is_local = True + elif is_remote_url(pretrained_model_name_or_path): + image_processor_file = pretrained_model_name_or_path + resolved_image_processor_file = download_url(pretrained_model_name_or_path) + else: + image_processor_file = image_processor_filename + try: + # Load from local folder or from cache or download from model Hub and cache + resolved_image_processor_files = [ + resolved_file + for filename in [image_processor_file, PROCESSOR_NAME] + if ( + resolved_file := cached_file( + pretrained_model_name_or_path, + filename=filename, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + revision=revision, + subfolder=subfolder, + _raise_exceptions_for_missing_entries=False, + ) + ) + is not None + ] + resolved_image_processor_file = resolved_image_processor_files[0] + except OSError: + # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to + # the original exception. + raise + except Exception: + # For any other exception, we throw a generic error. + raise OSError( + f"Can't load image processor for '{pretrained_model_name_or_path}'. If you were trying to load" + " it from 'https://huggingface.co/models', make sure you don't have a local directory with the" + f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a" + f" directory containing a {image_processor_filename} file" + ) + + try: + # Load image_processor dict + with open(resolved_image_processor_file, encoding="utf-8") as reader: + text = reader.read() + image_processor_dict = json.loads(text) + image_processor_dict = image_processor_dict.get("image_processor", image_processor_dict) + + except json.JSONDecodeError: + raise OSError( + f"It looks like the config file at '{resolved_image_processor_file}' is not a valid JSON file." + ) + + if is_local: + logger.info(f"loading configuration file {resolved_image_processor_file}") + else: + logger.info( + f"loading configuration file {image_processor_file} from cache at {resolved_image_processor_file}" + ) + + return image_processor_dict, kwargs + + @classmethod + def from_dict(cls, image_processor_dict: dict[str, Any], **kwargs): + """ + Instantiates a type of [`~image_processing_utils.ImageProcessingMixin`] from a Python dictionary of parameters. + + Args: + image_processor_dict (`dict[str, Any]`): + Dictionary that will be used to instantiate the image processor object. Such a dictionary can be + retrieved from a pretrained checkpoint by leveraging the + [`~image_processing_utils.ImageProcessingMixin.to_dict`] method. + kwargs (`dict[str, Any]`): + Additional parameters from which to initialize the image processor object. + + Returns: + [`~image_processing_utils.ImageProcessingMixin`]: The image processor object instantiated from those + parameters. + """ + image_processor_dict = image_processor_dict.copy() + return_unused_kwargs = kwargs.pop("return_unused_kwargs", False) + + # The `size` parameter is a dict and was previously an int or tuple in feature extractors. + # We set `size` here directly to the `image_processor_dict` so that it is converted to the appropriate + # dict within the image processor and isn't overwritten if `size` is passed in as a kwarg. + if "size" in kwargs and "size" in image_processor_dict: + image_processor_dict["size"] = kwargs.pop("size") + if "crop_size" in kwargs and "crop_size" in image_processor_dict: + image_processor_dict["crop_size"] = kwargs.pop("crop_size") + + image_processor = cls(**image_processor_dict) + + # Update image_processor with kwargs if needed + to_remove = [] + for key, value in kwargs.items(): + if hasattr(image_processor, key): + setattr(image_processor, key, value) + to_remove.append(key) + for key in to_remove: + kwargs.pop(key, None) + + logger.info(f"Image processor {image_processor}") + if return_unused_kwargs: + return image_processor, kwargs + else: + return image_processor + + def to_dict(self) -> dict[str, Any]: + """ + Serializes this instance to a Python dictionary. + + Returns: + `dict[str, Any]`: Dictionary of all the attributes that make up this image processor instance. + """ + output = copy.deepcopy(self.__dict__) + output["image_processor_type"] = self.__class__.__name__ + + return output + + @classmethod + def from_json_file(cls, json_file: Union[str, os.PathLike]): + """ + Instantiates a image processor of type [`~image_processing_utils.ImageProcessingMixin`] from the path to a JSON + file of parameters. + + Args: + json_file (`str` or `os.PathLike`): + Path to the JSON file containing the parameters. + + Returns: + A image processor of type [`~image_processing_utils.ImageProcessingMixin`]: The image_processor object + instantiated from that JSON file. + """ + with open(json_file, encoding="utf-8") as reader: + text = reader.read() + image_processor_dict = json.loads(text) + return cls(**image_processor_dict) + + def to_json_string(self) -> str: + """ + Serializes this instance to a JSON string. + + Returns: + `str`: String containing all the attributes that make up this feature_extractor instance in JSON format. + """ + dictionary = self.to_dict() + + for key, value in dictionary.items(): + if isinstance(value, np.ndarray): + dictionary[key] = value.tolist() + + # make sure private name "_processor_class" is correctly + # saved as "processor_class" + _processor_class = dictionary.pop("_processor_class", None) + if _processor_class is not None: + dictionary["processor_class"] = _processor_class + + return json.dumps(dictionary, indent=2, sort_keys=True) + "\n" + + def to_json_file(self, json_file_path: Union[str, os.PathLike]): + """ + Save this instance to a JSON file. + + Args: + json_file_path (`str` or `os.PathLike`): + Path to the JSON file in which this image_processor instance's parameters will be saved. + """ + with open(json_file_path, "w", encoding="utf-8") as writer: + writer.write(self.to_json_string()) + + def __repr__(self): + return f"{self.__class__.__name__} {self.to_json_string()}" + + @classmethod + def register_for_auto_class(cls, auto_class="AutoImageProcessor"): + """ + Register this class with a given auto class. This should only be used for custom image processors as the ones + in the library are already mapped with `AutoImageProcessor `. + + + + Args: + auto_class (`str` or `type`, *optional*, defaults to `"AutoImageProcessor "`): + The auto class to register this new image processor with. + """ + if not isinstance(auto_class, str): + auto_class = auto_class.__name__ + + import transformers.models.auto as auto_module + + if not hasattr(auto_module, auto_class): + raise ValueError(f"{auto_class} is not a valid auto class.") + + cls._auto_class = auto_class + + def fetch_images(self, image_url_or_urls: Union[str, list[str], list[list[str]]]): + """ + Convert a single or a list of urls into the corresponding `PIL.Image` objects. + + If a single url is passed, the return value will be a single object. If a list is passed a list of objects is + returned. + """ + if isinstance(image_url_or_urls, list): + return [self.fetch_images(x) for x in image_url_or_urls] + elif isinstance(image_url_or_urls, str): + return load_image(image_url_or_urls) + elif is_valid_image(image_url_or_urls): + return image_url_or_urls + else: + raise TypeError(f"only a single or a list of entries is supported but got type={type(image_url_or_urls)}") + + +ImageProcessingMixin.push_to_hub = copy_func(ImageProcessingMixin.push_to_hub) +if ImageProcessingMixin.push_to_hub.__doc__ is not None: + ImageProcessingMixin.push_to_hub.__doc__ = ImageProcessingMixin.push_to_hub.__doc__.format( + object="image processor", object_class="AutoImageProcessor", object_files="image processor file" + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/image_processing_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/image_processing_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..52b798c09f84cf014b7568dd0634d739d32e0508 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/image_processing_utils.py @@ -0,0 +1,317 @@ +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from collections.abc import Iterable +from typing import Optional, Union + +import numpy as np + +from .image_processing_base import BatchFeature, ImageProcessingMixin +from .image_transforms import center_crop, normalize, rescale +from .image_utils import ChannelDimension, get_image_size +from .utils import logging +from .utils.import_utils import requires + + +logger = logging.get_logger(__name__) + + +INIT_SERVICE_KWARGS = [ + "processor_class", + "image_processor_type", +] + + +@requires(backends=("vision",)) +class BaseImageProcessor(ImageProcessingMixin): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + @property + def is_fast(self) -> bool: + """ + `bool`: Whether or not this image processor is a fast processor (backed by PyTorch and TorchVision). + """ + return False + + def __call__(self, images, **kwargs) -> BatchFeature: + """Preprocess an image or a batch of images.""" + return self.preprocess(images, **kwargs) + + def preprocess(self, images, **kwargs) -> BatchFeature: + raise NotImplementedError("Each image processor must implement its own preprocess method") + + def rescale( + self, + image: np.ndarray, + scale: float, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Rescale an image by a scale factor. image = image * scale. + + Args: + image (`np.ndarray`): + Image to rescale. + scale (`float`): + The scaling factor to rescale pixel values by. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + + Returns: + `np.ndarray`: The rescaled image. + """ + return rescale(image, scale=scale, data_format=data_format, input_data_format=input_data_format, **kwargs) + + def normalize( + self, + image: np.ndarray, + mean: Union[float, Iterable[float]], + std: Union[float, Iterable[float]], + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Normalize an image. image = (image - image_mean) / image_std. + + Args: + image (`np.ndarray`): + Image to normalize. + mean (`float` or `Iterable[float]`): + Image mean to use for normalization. + std (`float` or `Iterable[float]`): + Image standard deviation to use for normalization. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + + Returns: + `np.ndarray`: The normalized image. + """ + return normalize( + image, mean=mean, std=std, data_format=data_format, input_data_format=input_data_format, **kwargs + ) + + def center_crop( + self, + image: np.ndarray, + size: dict[str, int], + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Center crop an image to `(size["height"], size["width"])`. If the input size is smaller than `crop_size` along + any edge, the image is padded with 0's and then center cropped. + + Args: + image (`np.ndarray`): + Image to center crop. + size (`dict[str, int]`): + Size of the output image. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + """ + size = get_size_dict(size) + if "height" not in size or "width" not in size: + raise ValueError(f"The size dictionary must have keys 'height' and 'width'. Got {size.keys()}") + return center_crop( + image, + size=(size["height"], size["width"]), + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + def to_dict(self): + encoder_dict = super().to_dict() + encoder_dict.pop("_valid_processor_keys", None) + return encoder_dict + + +VALID_SIZE_DICT_KEYS = ( + {"height", "width"}, + {"shortest_edge"}, + {"shortest_edge", "longest_edge"}, + {"longest_edge"}, + {"max_height", "max_width"}, +) + + +def is_valid_size_dict(size_dict): + if not isinstance(size_dict, dict): + return False + + size_dict_keys = set(size_dict.keys()) + for allowed_keys in VALID_SIZE_DICT_KEYS: + if size_dict_keys == allowed_keys: + return True + return False + + +def convert_to_size_dict( + size, max_size: Optional[int] = None, default_to_square: bool = True, height_width_order: bool = True +): + # By default, if size is an int we assume it represents a tuple of (size, size). + if isinstance(size, int) and default_to_square: + if max_size is not None: + raise ValueError("Cannot specify both size as an int, with default_to_square=True and max_size") + return {"height": size, "width": size} + # In other configs, if size is an int and default_to_square is False, size represents the length of + # the shortest edge after resizing. + elif isinstance(size, int) and not default_to_square: + size_dict = {"shortest_edge": size} + if max_size is not None: + size_dict["longest_edge"] = max_size + return size_dict + # Otherwise, if size is a tuple it's either (height, width) or (width, height) + elif isinstance(size, (tuple, list)) and height_width_order: + return {"height": size[0], "width": size[1]} + elif isinstance(size, (tuple, list)) and not height_width_order: + return {"height": size[1], "width": size[0]} + elif size is None and max_size is not None: + if default_to_square: + raise ValueError("Cannot specify both default_to_square=True and max_size") + return {"longest_edge": max_size} + + raise ValueError(f"Could not convert size input to size dict: {size}") + + +def get_size_dict( + size: Optional[Union[int, Iterable[int], dict[str, int]]] = None, + max_size: Optional[int] = None, + height_width_order: bool = True, + default_to_square: bool = True, + param_name="size", +) -> dict: + """ + Converts the old size parameter in the config into the new dict expected in the config. This is to ensure backwards + compatibility with the old image processor configs and removes ambiguity over whether the tuple is in (height, + width) or (width, height) format. + + - If `size` is tuple, it is converted to `{"height": size[0], "width": size[1]}` or `{"height": size[1], "width": + size[0]}` if `height_width_order` is `False`. + - If `size` is an int, and `default_to_square` is `True`, it is converted to `{"height": size, "width": size}`. + - If `size` is an int and `default_to_square` is False, it is converted to `{"shortest_edge": size}`. If `max_size` + is set, it is added to the dict as `{"longest_edge": max_size}`. + + Args: + size (`Union[int, Iterable[int], dict[str, int]]`, *optional*): + The `size` parameter to be cast into a size dictionary. + max_size (`Optional[int]`, *optional*): + The `max_size` parameter to be cast into a size dictionary. + height_width_order (`bool`, *optional*, defaults to `True`): + If `size` is a tuple, whether it's in (height, width) or (width, height) order. + default_to_square (`bool`, *optional*, defaults to `True`): + If `size` is an int, whether to default to a square image or not. + """ + if not isinstance(size, dict): + size_dict = convert_to_size_dict(size, max_size, default_to_square, height_width_order) + logger.info( + f"{param_name} should be a dictionary on of the following set of keys: {VALID_SIZE_DICT_KEYS}, got {size}." + f" Converted to {size_dict}.", + ) + else: + size_dict = size + + if not is_valid_size_dict(size_dict): + raise ValueError( + f"{param_name} must have one of the following set of keys: {VALID_SIZE_DICT_KEYS}, got {size_dict.keys()}" + ) + return size_dict + + +def select_best_resolution(original_size: tuple, possible_resolutions: list) -> tuple: + """ + Selects the best resolution from a list of possible resolutions based on the original size. + + This is done by calculating the effective and wasted resolution for each possible resolution. + + The best fit resolution is the one that maximizes the effective resolution and minimizes the wasted resolution. + + Args: + original_size (tuple): + The original size of the image in the format (height, width). + possible_resolutions (list): + A list of possible resolutions in the format [(height1, width1), (height2, width2), ...]. + + Returns: + tuple: The best fit resolution in the format (height, width). + """ + original_height, original_width = original_size + best_fit = None + max_effective_resolution = 0 + min_wasted_resolution = float("inf") + + for height, width in possible_resolutions: + scale = min(width / original_width, height / original_height) + downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale) + effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height) + wasted_resolution = (width * height) - effective_resolution + + if effective_resolution > max_effective_resolution or ( + effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution + ): + max_effective_resolution = effective_resolution + min_wasted_resolution = wasted_resolution + best_fit = (height, width) + + return best_fit + + +def get_patch_output_size(image, target_resolution, input_data_format): + """ + Given an image and a target resolution, calculate the output size of the image after cropping to the target + """ + original_height, original_width = get_image_size(image, channel_dim=input_data_format) + target_height, target_width = target_resolution + + scale_w = target_width / original_width + scale_h = target_height / original_height + + if scale_w < scale_h: + new_width = target_width + new_height = min(math.ceil(original_height * scale_w), target_height) + else: + new_height = target_height + new_width = min(math.ceil(original_width * scale_h), target_width) + + return new_height, new_width diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/image_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/image_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c5f4d4a3fa4cd2540d76c919a6548f52cbed76b4 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/image_utils.py @@ -0,0 +1,969 @@ +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import base64 +import os +from collections.abc import Iterable +from dataclasses import dataclass +from io import BytesIO +from typing import Optional, Union + +import numpy as np +import requests + +from .utils import ( + ExplicitEnum, + is_jax_tensor, + is_numpy_array, + is_tf_tensor, + is_torch_available, + is_torch_tensor, + is_torchvision_available, + is_vision_available, + logging, + requires_backends, + to_numpy, +) +from .utils.constants import ( # noqa: F401 + IMAGENET_DEFAULT_MEAN, + IMAGENET_DEFAULT_STD, + IMAGENET_STANDARD_MEAN, + IMAGENET_STANDARD_STD, + OPENAI_CLIP_MEAN, + OPENAI_CLIP_STD, +) + + +if is_vision_available(): + import PIL.Image + import PIL.ImageOps + + PILImageResampling = PIL.Image.Resampling + + if is_torchvision_available(): + from torchvision.transforms import InterpolationMode + + pil_torch_interpolation_mapping = { + PILImageResampling.NEAREST: InterpolationMode.NEAREST_EXACT, + PILImageResampling.BOX: InterpolationMode.BOX, + PILImageResampling.BILINEAR: InterpolationMode.BILINEAR, + PILImageResampling.HAMMING: InterpolationMode.HAMMING, + PILImageResampling.BICUBIC: InterpolationMode.BICUBIC, + PILImageResampling.LANCZOS: InterpolationMode.LANCZOS, + } + else: + pil_torch_interpolation_mapping = {} + + +if is_torch_available(): + import torch + + +logger = logging.get_logger(__name__) + + +ImageInput = Union[ + "PIL.Image.Image", np.ndarray, "torch.Tensor", list["PIL.Image.Image"], list[np.ndarray], list["torch.Tensor"] +] + + +class ChannelDimension(ExplicitEnum): + FIRST = "channels_first" + LAST = "channels_last" + + +class AnnotationFormat(ExplicitEnum): + COCO_DETECTION = "coco_detection" + COCO_PANOPTIC = "coco_panoptic" + + +class AnnotionFormat(ExplicitEnum): + COCO_DETECTION = AnnotationFormat.COCO_DETECTION.value + COCO_PANOPTIC = AnnotationFormat.COCO_PANOPTIC.value + + +AnnotationType = dict[str, Union[int, str, list[dict]]] + + +def is_pil_image(img): + return is_vision_available() and isinstance(img, PIL.Image.Image) + + +class ImageType(ExplicitEnum): + PIL = "pillow" + TORCH = "torch" + NUMPY = "numpy" + TENSORFLOW = "tensorflow" + JAX = "jax" + + +def get_image_type(image): + if is_pil_image(image): + return ImageType.PIL + if is_torch_tensor(image): + return ImageType.TORCH + if is_numpy_array(image): + return ImageType.NUMPY + if is_tf_tensor(image): + return ImageType.TENSORFLOW + if is_jax_tensor(image): + return ImageType.JAX + raise ValueError(f"Unrecognized image type {type(image)}") + + +def is_valid_image(img): + return is_pil_image(img) or is_numpy_array(img) or is_torch_tensor(img) or is_tf_tensor(img) or is_jax_tensor(img) + + +def is_valid_list_of_images(images: list): + return images and all(is_valid_image(image) for image in images) + + +def concatenate_list(input_list): + if isinstance(input_list[0], list): + return [item for sublist in input_list for item in sublist] + elif isinstance(input_list[0], np.ndarray): + return np.concatenate(input_list, axis=0) + elif isinstance(input_list[0], torch.Tensor): + return torch.cat(input_list, dim=0) + + +def valid_images(imgs): + # If we have an list of images, make sure every image is valid + if isinstance(imgs, (list, tuple)): + for img in imgs: + if not valid_images(img): + return False + # If not a list of tuple, we have been given a single image or batched tensor of images + elif not is_valid_image(imgs): + return False + return True + + +def is_batched(img): + if isinstance(img, (list, tuple)): + return is_valid_image(img[0]) + return False + + +def is_scaled_image(image: np.ndarray) -> bool: + """ + Checks to see whether the pixel values have already been rescaled to [0, 1]. + """ + if image.dtype == np.uint8: + return False + + # It's possible the image has pixel values in [0, 255] but is of floating type + return np.min(image) >= 0 and np.max(image) <= 1 + + +def make_list_of_images(images, expected_ndims: int = 3) -> list[ImageInput]: + """ + Ensure that the output is a list of images. If the input is a single image, it is converted to a list of length 1. + If the input is a batch of images, it is converted to a list of images. + + Args: + images (`ImageInput`): + Image of images to turn into a list of images. + expected_ndims (`int`, *optional*, defaults to 3): + Expected number of dimensions for a single input image. If the input image has a different number of + dimensions, an error is raised. + """ + if is_batched(images): + return images + + # Either the input is a single image, in which case we create a list of length 1 + if is_pil_image(images): + # PIL images are never batched + return [images] + + if is_valid_image(images): + if images.ndim == expected_ndims + 1: + # Batch of images + images = list(images) + elif images.ndim == expected_ndims: + # Single image + images = [images] + else: + raise ValueError( + f"Invalid image shape. Expected either {expected_ndims + 1} or {expected_ndims} dimensions, but got" + f" {images.ndim} dimensions." + ) + return images + raise ValueError( + "Invalid image type. Expected either PIL.Image.Image, numpy.ndarray, torch.Tensor, tf.Tensor or " + f"jax.ndarray, but got {type(images)}." + ) + + +def make_flat_list_of_images( + images: Union[list[ImageInput], ImageInput], + expected_ndims: int = 3, +) -> ImageInput: + """ + Ensure that the output is a flat list of images. If the input is a single image, it is converted to a list of length 1. + If the input is a nested list of images, it is converted to a flat list of images. + Args: + images (`Union[list[ImageInput], ImageInput]`): + The input image. + expected_ndims (`int`, *optional*, defaults to 3): + The expected number of dimensions for a single input image. + Returns: + list: A list of images or a 4d array of images. + """ + # If the input is a nested list of images, we flatten it + if ( + isinstance(images, (list, tuple)) + and all(isinstance(images_i, (list, tuple)) for images_i in images) + and all(is_valid_list_of_images(images_i) or not images_i for images_i in images) + ): + return [img for img_list in images for img in img_list] + + if isinstance(images, (list, tuple)) and is_valid_list_of_images(images): + if is_pil_image(images[0]) or images[0].ndim == expected_ndims: + return images + if images[0].ndim == expected_ndims + 1: + return [img for img_list in images for img in img_list] + + if is_valid_image(images): + if is_pil_image(images) or images.ndim == expected_ndims: + return [images] + if images.ndim == expected_ndims + 1: + return list(images) + + raise ValueError(f"Could not make a flat list of images from {images}") + + +def make_nested_list_of_images( + images: Union[list[ImageInput], ImageInput], + expected_ndims: int = 3, +) -> list[ImageInput]: + """ + Ensure that the output is a nested list of images. + Args: + images (`Union[list[ImageInput], ImageInput]`): + The input image. + expected_ndims (`int`, *optional*, defaults to 3): + The expected number of dimensions for a single input image. + Returns: + list: A list of list of images or a list of 4d array of images. + """ + # If it's a list of batches, it's already in the right format + if ( + isinstance(images, (list, tuple)) + and all(isinstance(images_i, (list, tuple)) for images_i in images) + and all(is_valid_list_of_images(images_i) or not images_i for images_i in images) + ): + return images + + # If it's a list of images, it's a single batch, so convert it to a list of lists + if isinstance(images, (list, tuple)) and is_valid_list_of_images(images): + if is_pil_image(images[0]) or images[0].ndim == expected_ndims: + return [images] + if images[0].ndim == expected_ndims + 1: + return [list(image) for image in images] + + # If it's a single image, convert it to a list of lists + if is_valid_image(images): + if is_pil_image(images) or images.ndim == expected_ndims: + return [[images]] + if images.ndim == expected_ndims + 1: + return [list(images)] + + raise ValueError("Invalid input type. Must be a single image, a list of images, or a list of batches of images.") + + +def to_numpy_array(img) -> np.ndarray: + if not is_valid_image(img): + raise ValueError(f"Invalid image type: {type(img)}") + + if is_vision_available() and isinstance(img, PIL.Image.Image): + return np.array(img) + return to_numpy(img) + + +def infer_channel_dimension_format( + image: np.ndarray, num_channels: Optional[Union[int, tuple[int, ...]]] = None +) -> ChannelDimension: + """ + Infers the channel dimension format of `image`. + + Args: + image (`np.ndarray`): + The image to infer the channel dimension of. + num_channels (`int` or `tuple[int, ...]`, *optional*, defaults to `(1, 3)`): + The number of channels of the image. + + Returns: + The channel dimension of the image. + """ + num_channels = num_channels if num_channels is not None else (1, 3) + num_channels = (num_channels,) if isinstance(num_channels, int) else num_channels + + if image.ndim == 3: + first_dim, last_dim = 0, 2 + elif image.ndim == 4: + first_dim, last_dim = 1, 3 + elif image.ndim == 5: + first_dim, last_dim = 2, 4 + else: + raise ValueError(f"Unsupported number of image dimensions: {image.ndim}") + + if image.shape[first_dim] in num_channels and image.shape[last_dim] in num_channels: + logger.warning( + f"The channel dimension is ambiguous. Got image shape {image.shape}. Assuming channels are the first dimension. Use the [input_data_format](https://huggingface.co/docs/transformers/main/internal/image_processing_utils#transformers.image_transforms.rescale.input_data_format) parameter to assign the channel dimension." + ) + return ChannelDimension.FIRST + elif image.shape[first_dim] in num_channels: + return ChannelDimension.FIRST + elif image.shape[last_dim] in num_channels: + return ChannelDimension.LAST + raise ValueError("Unable to infer channel dimension format") + + +def get_channel_dimension_axis( + image: np.ndarray, input_data_format: Optional[Union[ChannelDimension, str]] = None +) -> int: + """ + Returns the channel dimension axis of the image. + + Args: + image (`np.ndarray`): + The image to get the channel dimension axis of. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the image. If `None`, will infer the channel dimension from the image. + + Returns: + The channel dimension axis of the image. + """ + if input_data_format is None: + input_data_format = infer_channel_dimension_format(image) + if input_data_format == ChannelDimension.FIRST: + return image.ndim - 3 + elif input_data_format == ChannelDimension.LAST: + return image.ndim - 1 + raise ValueError(f"Unsupported data format: {input_data_format}") + + +def get_image_size(image: np.ndarray, channel_dim: Optional[ChannelDimension] = None) -> tuple[int, int]: + """ + Returns the (height, width) dimensions of the image. + + Args: + image (`np.ndarray`): + The image to get the dimensions of. + channel_dim (`ChannelDimension`, *optional*): + Which dimension the channel dimension is in. If `None`, will infer the channel dimension from the image. + + Returns: + A tuple of the image's height and width. + """ + if channel_dim is None: + channel_dim = infer_channel_dimension_format(image) + + if channel_dim == ChannelDimension.FIRST: + return image.shape[-2], image.shape[-1] + elif channel_dim == ChannelDimension.LAST: + return image.shape[-3], image.shape[-2] + else: + raise ValueError(f"Unsupported data format: {channel_dim}") + + +def get_image_size_for_max_height_width( + image_size: tuple[int, int], + max_height: int, + max_width: int, +) -> tuple[int, int]: + """ + Computes the output image size given the input image and the maximum allowed height and width. Keep aspect ratio. + Important, even if image_height < max_height and image_width < max_width, the image will be resized + to at least one of the edges be equal to max_height or max_width. + + For example: + - input_size: (100, 200), max_height: 50, max_width: 50 -> output_size: (25, 50) + - input_size: (100, 200), max_height: 200, max_width: 500 -> output_size: (200, 400) + + Args: + image_size (`tuple[int, int]`): + The image to resize. + max_height (`int`): + The maximum allowed height. + max_width (`int`): + The maximum allowed width. + """ + height, width = image_size + height_scale = max_height / height + width_scale = max_width / width + min_scale = min(height_scale, width_scale) + new_height = int(height * min_scale) + new_width = int(width * min_scale) + return new_height, new_width + + +def is_valid_annotation_coco_detection(annotation: dict[str, Union[list, tuple]]) -> bool: + if ( + isinstance(annotation, dict) + and "image_id" in annotation + and "annotations" in annotation + and isinstance(annotation["annotations"], (list, tuple)) + and ( + # an image can have no annotations + len(annotation["annotations"]) == 0 or isinstance(annotation["annotations"][0], dict) + ) + ): + return True + return False + + +def is_valid_annotation_coco_panoptic(annotation: dict[str, Union[list, tuple]]) -> bool: + if ( + isinstance(annotation, dict) + and "image_id" in annotation + and "segments_info" in annotation + and "file_name" in annotation + and isinstance(annotation["segments_info"], (list, tuple)) + and ( + # an image can have no segments + len(annotation["segments_info"]) == 0 or isinstance(annotation["segments_info"][0], dict) + ) + ): + return True + return False + + +def valid_coco_detection_annotations(annotations: Iterable[dict[str, Union[list, tuple]]]) -> bool: + return all(is_valid_annotation_coco_detection(ann) for ann in annotations) + + +def valid_coco_panoptic_annotations(annotations: Iterable[dict[str, Union[list, tuple]]]) -> bool: + return all(is_valid_annotation_coco_panoptic(ann) for ann in annotations) + + +def load_image(image: Union[str, "PIL.Image.Image"], timeout: Optional[float] = None) -> "PIL.Image.Image": + """ + Loads `image` to a PIL Image. + + Args: + image (`str` or `PIL.Image.Image`): + The image to convert to the PIL Image format. + timeout (`float`, *optional*): + The timeout value in seconds for the URL request. + + Returns: + `PIL.Image.Image`: A PIL Image. + """ + requires_backends(load_image, ["vision"]) + if isinstance(image, str): + if image.startswith("http://") or image.startswith("https://"): + # We need to actually check for a real protocol, otherwise it's impossible to use a local file + # like http_huggingface_co.png + image = PIL.Image.open(BytesIO(requests.get(image, timeout=timeout).content)) + elif os.path.isfile(image): + image = PIL.Image.open(image) + else: + if image.startswith("data:image/"): + image = image.split(",")[1] + + # Try to load as base64 + try: + b64 = base64.decodebytes(image.encode()) + image = PIL.Image.open(BytesIO(b64)) + except Exception as e: + raise ValueError( + f"Incorrect image source. Must be a valid URL starting with `http://` or `https://`, a valid path to an image file, or a base64 encoded string. Got {image}. Failed with {e}" + ) + elif not isinstance(image, PIL.Image.Image): + raise TypeError( + "Incorrect format used for image. Should be an url linking to an image, a base64 string, a local path, or a PIL image." + ) + image = PIL.ImageOps.exif_transpose(image) + image = image.convert("RGB") + return image + + +def load_images( + images: Union[list, tuple, str, "PIL.Image.Image"], timeout: Optional[float] = None +) -> Union["PIL.Image.Image", list["PIL.Image.Image"], list[list["PIL.Image.Image"]]]: + """Loads images, handling different levels of nesting. + + Args: + images: A single image, a list of images, or a list of lists of images to load. + timeout: Timeout for loading images. + + Returns: + A single image, a list of images, a list of lists of images. + """ + if isinstance(images, (list, tuple)): + if len(images) and isinstance(images[0], (list, tuple)): + return [[load_image(image, timeout=timeout) for image in image_group] for image_group in images] + else: + return [load_image(image, timeout=timeout) for image in images] + else: + return load_image(images, timeout=timeout) + + +def validate_preprocess_arguments( + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + do_pad: Optional[bool] = None, + pad_size: Optional[Union[dict[str, int], int]] = None, + do_center_crop: Optional[bool] = None, + crop_size: Optional[dict[str, int]] = None, + do_resize: Optional[bool] = None, + size: Optional[dict[str, int]] = None, + resample: Optional["PILImageResampling"] = None, + interpolation: Optional["InterpolationMode"] = None, +): + """ + Checks validity of typically used arguments in an `ImageProcessor` `preprocess` method. + Raises `ValueError` if arguments incompatibility is caught. + Many incompatibilities are model-specific. `do_pad` sometimes needs `size_divisor`, + sometimes `size_divisibility`, and sometimes `size`. New models and processors added should follow + existing arguments when possible. + + """ + if do_rescale and rescale_factor is None: + raise ValueError("`rescale_factor` must be specified if `do_rescale` is `True`.") + + if do_pad and pad_size is None: + # Processors pad images using different args depending on the model, so the below check is pointless + # but we keep it for BC for now. TODO: remove in v5 + # Usually padding can be called with: + # - "pad_size/size" if we're padding to specific values + # - "size_divisor" if we're padding to any value divisible by X + # - "None" if we're padding to the maximum size image in batch + raise ValueError( + "Depending on the model, `size_divisor` or `pad_size` or `size` must be specified if `do_pad` is `True`." + ) + + if do_normalize and (image_mean is None or image_std is None): + raise ValueError("`image_mean` and `image_std` must both be specified if `do_normalize` is `True`.") + + if do_center_crop and crop_size is None: + raise ValueError("`crop_size` must be specified if `do_center_crop` is `True`.") + + if interpolation is not None and resample is not None: + raise ValueError( + "Only one of `interpolation` and `resample` should be specified, depending on image processor type." + ) + + if do_resize and not (size is not None and (resample is not None or interpolation is not None)): + raise ValueError("`size` and `resample/interpolation` must be specified if `do_resize` is `True`.") + + +# In the future we can add a TF implementation here when we have TF models. +class ImageFeatureExtractionMixin: + """ + Mixin that contain utilities for preparing image features. + """ + + def _ensure_format_supported(self, image): + if not isinstance(image, (PIL.Image.Image, np.ndarray)) and not is_torch_tensor(image): + raise ValueError( + f"Got type {type(image)} which is not supported, only `PIL.Image.Image`, `np.ndarray` and " + "`torch.Tensor` are." + ) + + def to_pil_image(self, image, rescale=None): + """ + Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if + needed. + + Args: + image (`PIL.Image.Image` or `numpy.ndarray` or `torch.Tensor`): + The image to convert to the PIL Image format. + rescale (`bool`, *optional*): + Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will + default to `True` if the image type is a floating type, `False` otherwise. + """ + self._ensure_format_supported(image) + + if is_torch_tensor(image): + image = image.numpy() + + if isinstance(image, np.ndarray): + if rescale is None: + # rescale default to the array being of floating type. + rescale = isinstance(image.flat[0], np.floating) + # If the channel as been moved to first dim, we put it back at the end. + if image.ndim == 3 and image.shape[0] in [1, 3]: + image = image.transpose(1, 2, 0) + if rescale: + image = image * 255 + image = image.astype(np.uint8) + return PIL.Image.fromarray(image) + return image + + def convert_rgb(self, image): + """ + Converts `PIL.Image.Image` to RGB format. + + Args: + image (`PIL.Image.Image`): + The image to convert. + """ + self._ensure_format_supported(image) + if not isinstance(image, PIL.Image.Image): + return image + + return image.convert("RGB") + + def rescale(self, image: np.ndarray, scale: Union[float, int]) -> np.ndarray: + """ + Rescale a numpy image by scale amount + """ + self._ensure_format_supported(image) + return image * scale + + def to_numpy_array(self, image, rescale=None, channel_first=True): + """ + Converts `image` to a numpy array. Optionally rescales it and puts the channel dimension as the first + dimension. + + Args: + image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`): + The image to convert to a NumPy array. + rescale (`bool`, *optional*): + Whether or not to apply the scaling factor (to make pixel values floats between 0. and 1.). Will + default to `True` if the image is a PIL Image or an array/tensor of integers, `False` otherwise. + channel_first (`bool`, *optional*, defaults to `True`): + Whether or not to permute the dimensions of the image to put the channel dimension first. + """ + self._ensure_format_supported(image) + + if isinstance(image, PIL.Image.Image): + image = np.array(image) + + if is_torch_tensor(image): + image = image.numpy() + + rescale = isinstance(image.flat[0], np.integer) if rescale is None else rescale + + if rescale: + image = self.rescale(image.astype(np.float32), 1 / 255.0) + + if channel_first and image.ndim == 3: + image = image.transpose(2, 0, 1) + + return image + + def expand_dims(self, image): + """ + Expands 2-dimensional `image` to 3 dimensions. + + Args: + image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`): + The image to expand. + """ + self._ensure_format_supported(image) + + # Do nothing if PIL image + if isinstance(image, PIL.Image.Image): + return image + + if is_torch_tensor(image): + image = image.unsqueeze(0) + else: + image = np.expand_dims(image, axis=0) + return image + + def normalize(self, image, mean, std, rescale=False): + """ + Normalizes `image` with `mean` and `std`. Note that this will trigger a conversion of `image` to a NumPy array + if it's a PIL Image. + + Args: + image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`): + The image to normalize. + mean (`list[float]` or `np.ndarray` or `torch.Tensor`): + The mean (per channel) to use for normalization. + std (`list[float]` or `np.ndarray` or `torch.Tensor`): + The standard deviation (per channel) to use for normalization. + rescale (`bool`, *optional*, defaults to `False`): + Whether or not to rescale the image to be between 0 and 1. If a PIL image is provided, scaling will + happen automatically. + """ + self._ensure_format_supported(image) + + if isinstance(image, PIL.Image.Image): + image = self.to_numpy_array(image, rescale=True) + # If the input image is a PIL image, it automatically gets rescaled. If it's another + # type it may need rescaling. + elif rescale: + if isinstance(image, np.ndarray): + image = self.rescale(image.astype(np.float32), 1 / 255.0) + elif is_torch_tensor(image): + image = self.rescale(image.float(), 1 / 255.0) + + if isinstance(image, np.ndarray): + if not isinstance(mean, np.ndarray): + mean = np.array(mean).astype(image.dtype) + if not isinstance(std, np.ndarray): + std = np.array(std).astype(image.dtype) + elif is_torch_tensor(image): + import torch + + if not isinstance(mean, torch.Tensor): + if isinstance(mean, np.ndarray): + mean = torch.from_numpy(mean) + else: + mean = torch.tensor(mean) + if not isinstance(std, torch.Tensor): + if isinstance(std, np.ndarray): + std = torch.from_numpy(std) + else: + std = torch.tensor(std) + + if image.ndim == 3 and image.shape[0] in [1, 3]: + return (image - mean[:, None, None]) / std[:, None, None] + else: + return (image - mean) / std + + def resize(self, image, size, resample=None, default_to_square=True, max_size=None): + """ + Resizes `image`. Enforces conversion of input to PIL.Image. + + Args: + image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`): + The image to resize. + size (`int` or `tuple[int, int]`): + The size to use for resizing the image. If `size` is a sequence like (h, w), output size will be + matched to this. + + If `size` is an int and `default_to_square` is `True`, then image will be resized to (size, size). If + `size` is an int and `default_to_square` is `False`, then smaller edge of the image will be matched to + this number. i.e, if height > width, then image will be rescaled to (size * height / width, size). + resample (`int`, *optional*, defaults to `PILImageResampling.BILINEAR`): + The filter to user for resampling. + default_to_square (`bool`, *optional*, defaults to `True`): + How to convert `size` when it is a single int. If set to `True`, the `size` will be converted to a + square (`size`,`size`). If set to `False`, will replicate + [`torchvision.transforms.Resize`](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.Resize) + with support for resizing only the smallest edge and providing an optional `max_size`. + max_size (`int`, *optional*, defaults to `None`): + The maximum allowed for the longer edge of the resized image: if the longer edge of the image is + greater than `max_size` after being resized according to `size`, then the image is resized again so + that the longer edge is equal to `max_size`. As a result, `size` might be overruled, i.e the smaller + edge may be shorter than `size`. Only used if `default_to_square` is `False`. + + Returns: + image: A resized `PIL.Image.Image`. + """ + resample = resample if resample is not None else PILImageResampling.BILINEAR + + self._ensure_format_supported(image) + + if not isinstance(image, PIL.Image.Image): + image = self.to_pil_image(image) + + if isinstance(size, list): + size = tuple(size) + + if isinstance(size, int) or len(size) == 1: + if default_to_square: + size = (size, size) if isinstance(size, int) else (size[0], size[0]) + else: + width, height = image.size + # specified size only for the smallest edge + short, long = (width, height) if width <= height else (height, width) + requested_new_short = size if isinstance(size, int) else size[0] + + if short == requested_new_short: + return image + + new_short, new_long = requested_new_short, int(requested_new_short * long / short) + + if max_size is not None: + if max_size <= requested_new_short: + raise ValueError( + f"max_size = {max_size} must be strictly greater than the requested " + f"size for the smaller edge size = {size}" + ) + if new_long > max_size: + new_short, new_long = int(max_size * new_short / new_long), max_size + + size = (new_short, new_long) if width <= height else (new_long, new_short) + + return image.resize(size, resample=resample) + + def center_crop(self, image, size): + """ + Crops `image` to the given size using a center crop. Note that if the image is too small to be cropped to the + size given, it will be padded (so the returned result has the size asked). + + Args: + image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor` of shape (n_channels, height, width) or (height, width, n_channels)): + The image to resize. + size (`int` or `tuple[int, int]`): + The size to which crop the image. + + Returns: + new_image: A center cropped `PIL.Image.Image` or `np.ndarray` or `torch.Tensor` of shape: (n_channels, + height, width). + """ + self._ensure_format_supported(image) + + if not isinstance(size, tuple): + size = (size, size) + + # PIL Image.size is (width, height) but NumPy array and torch Tensors have (height, width) + if is_torch_tensor(image) or isinstance(image, np.ndarray): + if image.ndim == 2: + image = self.expand_dims(image) + image_shape = image.shape[1:] if image.shape[0] in [1, 3] else image.shape[:2] + else: + image_shape = (image.size[1], image.size[0]) + + top = (image_shape[0] - size[0]) // 2 + bottom = top + size[0] # In case size is odd, (image_shape[0] + size[0]) // 2 won't give the proper result. + left = (image_shape[1] - size[1]) // 2 + right = left + size[1] # In case size is odd, (image_shape[1] + size[1]) // 2 won't give the proper result. + + # For PIL Images we have a method to crop directly. + if isinstance(image, PIL.Image.Image): + return image.crop((left, top, right, bottom)) + + # Check if image is in (n_channels, height, width) or (height, width, n_channels) format + channel_first = image.shape[0] in [1, 3] + + # Transpose (height, width, n_channels) format images + if not channel_first: + if isinstance(image, np.ndarray): + image = image.transpose(2, 0, 1) + if is_torch_tensor(image): + image = image.permute(2, 0, 1) + + # Check if cropped area is within image boundaries + if top >= 0 and bottom <= image_shape[0] and left >= 0 and right <= image_shape[1]: + return image[..., top:bottom, left:right] + + # Otherwise, we may need to pad if the image is too small. Oh joy... + new_shape = image.shape[:-2] + (max(size[0], image_shape[0]), max(size[1], image_shape[1])) + if isinstance(image, np.ndarray): + new_image = np.zeros_like(image, shape=new_shape) + elif is_torch_tensor(image): + new_image = image.new_zeros(new_shape) + + top_pad = (new_shape[-2] - image_shape[0]) // 2 + bottom_pad = top_pad + image_shape[0] + left_pad = (new_shape[-1] - image_shape[1]) // 2 + right_pad = left_pad + image_shape[1] + new_image[..., top_pad:bottom_pad, left_pad:right_pad] = image + + top += top_pad + bottom += top_pad + left += left_pad + right += left_pad + + new_image = new_image[ + ..., max(0, top) : min(new_image.shape[-2], bottom), max(0, left) : min(new_image.shape[-1], right) + ] + + return new_image + + def flip_channel_order(self, image): + """ + Flips the channel order of `image` from RGB to BGR, or vice versa. Note that this will trigger a conversion of + `image` to a NumPy array if it's a PIL Image. + + Args: + image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`): + The image whose color channels to flip. If `np.ndarray` or `torch.Tensor`, the channel dimension should + be first. + """ + self._ensure_format_supported(image) + + if isinstance(image, PIL.Image.Image): + image = self.to_numpy_array(image) + + return image[::-1, :, :] + + def rotate(self, image, angle, resample=None, expand=0, center=None, translate=None, fillcolor=None): + """ + Returns a rotated copy of `image`. This method returns a copy of `image`, rotated the given number of degrees + counter clockwise around its centre. + + Args: + image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`): + The image to rotate. If `np.ndarray` or `torch.Tensor`, will be converted to `PIL.Image.Image` before + rotating. + + Returns: + image: A rotated `PIL.Image.Image`. + """ + resample = resample if resample is not None else PIL.Image.NEAREST + + self._ensure_format_supported(image) + + if not isinstance(image, PIL.Image.Image): + image = self.to_pil_image(image) + + return image.rotate( + angle, resample=resample, expand=expand, center=center, translate=translate, fillcolor=fillcolor + ) + + +def validate_annotations( + annotation_format: AnnotationFormat, + supported_annotation_formats: tuple[AnnotationFormat, ...], + annotations: list[dict], +) -> None: + if annotation_format not in supported_annotation_formats: + raise ValueError(f"Unsupported annotation format: {format} must be one of {supported_annotation_formats}") + + if annotation_format is AnnotationFormat.COCO_DETECTION: + if not valid_coco_detection_annotations(annotations): + raise ValueError( + "Invalid COCO detection annotations. Annotations must a dict (single image) or list of dicts " + "(batch of images) with the following keys: `image_id` and `annotations`, with the latter " + "being a list of annotations in the COCO format." + ) + + if annotation_format is AnnotationFormat.COCO_PANOPTIC: + if not valid_coco_panoptic_annotations(annotations): + raise ValueError( + "Invalid COCO panoptic annotations. Annotations must a dict (single image) or list of dicts " + "(batch of images) with the following keys: `image_id`, `file_name` and `segments_info`, with " + "the latter being a list of annotations in the COCO format." + ) + + +def validate_kwargs(valid_processor_keys: list[str], captured_kwargs: list[str]): + unused_keys = set(captured_kwargs).difference(set(valid_processor_keys)) + if unused_keys: + unused_key_str = ", ".join(unused_keys) + # TODO raise a warning here instead of simply logging? + logger.warning(f"Unused or unrecognized kwargs: {unused_key_str}.") + + +@dataclass(frozen=True) +class SizeDict: + """ + Hashable dictionary to store image size information. + """ + + height: Optional[int] = None + width: Optional[int] = None + longest_edge: Optional[int] = None + shortest_edge: Optional[int] = None + max_height: Optional[int] = None + max_width: Optional[int] = None + + def __getitem__(self, key): + if hasattr(self, key): + return getattr(self, key) + raise KeyError(f"Key {key} not found in SizeDict.") diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/keras_callbacks.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/keras_callbacks.py new file mode 100644 index 0000000000000000000000000000000000000000..ab7fc4615b473d59c903260a8c1ec80b24f4af7b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/keras_callbacks.py @@ -0,0 +1,413 @@ +import logging +import os +from pathlib import Path +from time import sleep +from typing import Callable, Optional, Union + +import numpy as np +import tensorflow as tf +from huggingface_hub import Repository, create_repo +from packaging.version import parse + +from . import IntervalStrategy, PreTrainedTokenizerBase +from .modelcard import TrainingSummary +from .modeling_tf_utils import keras + + +logger = logging.getLogger(__name__) + + +class KerasMetricCallback(keras.callbacks.Callback): + """ + Callback to compute metrics at the end of every epoch. Unlike normal Keras metrics, these do not need to be + compilable by TF. It is particularly useful for common NLP metrics like BLEU and ROUGE that require string + operations or generation loops that cannot be compiled. Predictions (or generations) will be computed on the + `eval_dataset` before being passed to the `metric_fn` in `np.ndarray` format. The `metric_fn` should compute + metrics and return a dict mapping metric names to metric values. + + We provide an example of a suitable metric_fn that computes ROUGE scores for a summarization model below. Note that + this example skips some post-processing for readability and simplicity, and should probably not be used as-is! + + ```py + from datasets import load_metric + + rouge_metric = load_metric("rouge") + + + def rouge_fn(predictions, labels): + decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True) + decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) + result = rouge_metric.compute(predictions=decoded_predictions, references=decoded_labels) + return {key: value.mid.fmeasure * 100 for key, value in result.items()} + ``` + + The above function will return a dict containing values which will be logged like any other Keras metric: + + ``` + {'rouge1': 37.4199, 'rouge2': 13.9768, 'rougeL': 34.361, 'rougeLsum': 35.0781 + ``` + + Args: + metric_fn (`Callable`): + Metric function provided by the user. It will be called with two arguments - `predictions` and `labels`. + These contain the model's outputs and matching labels from the dataset. It should return a dict mapping + metric names to numerical values. + eval_dataset (`tf.data.Dataset` or `dict` or `tuple` or `np.ndarray` or `tf.Tensor`): + Validation data to be used to generate predictions for the `metric_fn`. + output_cols (`list[str], *optional*): + A list of columns to be retained from the model output as the predictions. Defaults to all. + label_cols ('`list[str]`, *optional*'): + A list of columns to be retained from the input dataset as the labels. Will be autodetected if this is not + supplied. + batch_size (`int`, *optional*): + Batch size. Only used when the data is not a pre-batched `tf.data.Dataset`. + predict_with_generate (`bool`, *optional*, defaults to `False`): + Whether we should use `model.generate()` to get outputs for the model. + use_xla_generation (`bool`, *optional*, defaults to `False`): + If we're generating, whether to compile model generation with XLA. This can massively increase the speed of + generation (up to 100X speedup) but will require a new XLA compilation for each input shape. When using XLA + generation, it's a good idea to pad your inputs to the same size, or to use the `pad_to_multiple_of` + argument in your `tokenizer` or `DataCollator`, which will reduce the number of unique input shapes and + save a lot of compilation time. This option has no effect is `predict_with_generate` is `False`. + generate_kwargs (`dict`, *optional*): + Keyword arguments to pass to `model.generate()` when generating. Has no effect if `predict_with_generate` + is `False`. + + """ + + def __init__( + self, + metric_fn: Callable, + eval_dataset: Union[tf.data.Dataset, np.ndarray, tf.Tensor, tuple, dict], + output_cols: Optional[list[str]] = None, + label_cols: Optional[list[str]] = None, + batch_size: Optional[int] = None, + predict_with_generate: bool = False, + use_xla_generation: bool = False, + generate_kwargs: Optional[dict] = None, + ): + super().__init__() + self.metric_fn = metric_fn + self.batch_size = batch_size + if not isinstance(eval_dataset, tf.data.Dataset): + if batch_size is None: + raise ValueError( + "When passing data to KerasMetricCallback that is not a pre-batched tf.data.Dataset " + "the batch_size argument must be set." + ) + # Wrap a tf.data.Dataset around it + eval_dataset = tf.data.Dataset.from_tensor_slices(eval_dataset).batch(batch_size, drop_remainder=False) + self.eval_dataset = eval_dataset + self.predict_with_generate = predict_with_generate + self.output_cols = output_cols + + # This next block attempts to parse out which elements of the dataset should be appended to the labels list + # that is passed to the metric_fn + if isinstance(eval_dataset.element_spec, tuple) and len(eval_dataset.element_spec) == 2: + input_spec, label_spec = eval_dataset.element_spec + else: + input_spec = eval_dataset.element_spec + label_spec = None + if label_cols is not None: + for label in label_cols: + if label not in input_spec: + raise ValueError(f"Label {label} is in label_cols but could not be found in the dataset inputs!") + self.label_cols = label_cols + self.use_keras_label = False + elif label_spec is not None: + # If the dataset inputs are split into a 2-tuple of inputs and labels, + # assume the second element is the labels + self.label_cols = None + self.use_keras_label = True + elif "labels" in input_spec: + self.label_cols = ["labels"] + self.use_keras_label = False + logging.warning("No label_cols specified for KerasMetricCallback, assuming you want the 'labels' key.") + elif "start_positions" in input_spec and "end_positions" in input_spec: + self.label_cols = ["start_positions", "end_positions"] + self.use_keras_label = False + logging.warning( + "No label_cols specified for KerasMetricCallback, assuming you want the " + "start_positions and end_positions keys." + ) + else: + raise ValueError("Could not autodetect label_cols for KerasMetricCallback, please specify them!") + if parse(tf.__version__) < parse("2.7"): + logging.warning("TF versions less than 2.7 may encounter issues with KerasMetricCallback!") + + self.use_xla_generation = use_xla_generation + self.generate_kwargs = {} if generate_kwargs is None else generate_kwargs + + self.generation_function = None + + @staticmethod + def _concatenate_batches(batches, padding_index=-100): + # If all batches are unidimensional or same length, do a simple concatenation + if batches[0].ndim == 1 or all(batch.shape[1] == batches[0].shape[1] for batch in batches): + return np.concatenate(batches, axis=0) + + # Welp, they're not the same length. Let's do some padding + max_len = max([batch.shape[1] for batch in batches]) + num_samples = sum([batch.shape[0] for batch in batches]) + output = np.full_like( + batches[0], fill_value=padding_index, shape=[num_samples, max_len] + list(batches[0].shape[2:]) + ) + # i keeps track of which part of the concatenated array we're writing the next batch to + i = 0 + for batch in batches: + output[i : i + len(batch), : batch.shape[1]] = batch + i += len(batch) + return output + + def _postprocess_predictions_or_labels(self, inputs): + if isinstance(inputs[0], dict): + outputs = {} + for key in inputs[0]: + outputs[key] = self._concatenate_batches([batch[key] for batch in inputs]) + # If it's a dict with only one key, just return the array + if len(outputs) == 1: + outputs = list(outputs.values())[0] + elif isinstance(inputs[0], (tuple, list)): + outputs = [] + for input_list in zip(*inputs): + outputs.append(self._concatenate_batches(input_list)) + if len(outputs) == 1: + outputs = outputs[0] # If it's a list with only one element, just return the array + elif isinstance(inputs[0], np.ndarray): + outputs = self._concatenate_batches(inputs) + elif isinstance(inputs[0], tf.Tensor): + outputs = self._concatenate_batches([tensor.numpy() for tensor in inputs]) + else: + raise TypeError(f"Couldn't handle batch of type {type(inputs[0])}!") + return outputs + + def on_epoch_end(self, epoch, logs=None): + if hasattr(self.model, "config"): + ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", []) + else: + ignore_keys = [] + + main_input_name = None + if self.predict_with_generate: + # This dense conditional recognizes the case where we have an encoder-decoder model, but + # avoids getting tangled up when we just have a model with a layer called 'encoder' + if hasattr(self.model, "encoder") and hasattr(self.model.encoder, "main_input_name"): + main_input_name = self.model.encoder.main_input_name + else: + main_input_name = getattr(self.model, "main_input_name", "input_ids") + + if self.use_xla_generation and self.generation_function is None: + + def generation_function(inputs, attention_mask): + return self.model.generate(inputs, attention_mask=attention_mask, **self.generate_kwargs) + + self.generation_function = tf.function(generation_function, jit_compile=True) + + prediction_list = [] + label_list = [] + + # The whole predict/generate loop is handled inside this method + for batch in self.eval_dataset: + if isinstance(batch, tuple): + batch, labels = batch + else: + labels = None + if self.predict_with_generate: + if isinstance(batch, dict): + generation_inputs = batch[main_input_name] + attention_mask = batch.get("attention_mask", None) + else: + generation_inputs = batch + attention_mask = None + if self.use_xla_generation: + predictions = self.generation_function(generation_inputs, attention_mask=attention_mask) + else: + predictions = self.model.generate( + generation_inputs, attention_mask=attention_mask, **self.generate_kwargs + ) + else: + predictions = self.model.predict_on_batch(batch) + if isinstance(predictions, dict): + # This converts any dict-subclass to a regular dict + # Keras REALLY doesn't like it when we pass around a BatchEncoding or other derived class + predictions = dict(predictions) + if self.output_cols is not None: + predictions = {key: predictions[key] for key in self.output_cols} + else: + predictions = { + key: val for key, val in predictions.items() if key not in ignore_keys + ["loss"] + } + prediction_list.append(predictions) + if not self.use_keras_label: + labels = {key: batch[key].numpy() for key in self.label_cols} + elif isinstance(labels, dict): + labels = {key: array.numpy() for key, array in labels.items()} + elif isinstance(labels, (list, tuple)): + labels = [array.numpy() for array in labels] + elif isinstance(labels, tf.Tensor): + labels = labels.numpy() + else: + raise TypeError(f"Confused by labels of type {type(labels)}") + label_list.append(labels) + + all_preds = self._postprocess_predictions_or_labels(prediction_list) + all_labels = self._postprocess_predictions_or_labels(label_list) + + metric_output = self.metric_fn((all_preds, all_labels)) + if not isinstance(metric_output, dict): + raise TypeError( + f"metric_fn should return a dict mapping metric names to values but instead returned {metric_output}" + ) + # This is the critical bit - Keras passes a dict containing the loss and standard metric values for this epoch + # in the logs argument. Ordinarily, this is so the callback can read them, but in this case we write a bunch of + # new keys in there, which will then get read by the History callback and treated like any other metric value. + # I promise that I have it in writing from Chollet that this is okay. + logs.update(metric_output) + + +class PushToHubCallback(keras.callbacks.Callback): + """ + Callback that will save and push the model to the Hub regularly. By default, it pushes once per epoch, but this can + be changed with the `save_strategy` argument. Pushed models can be accessed like any other model on the hub, such + as with the `from_pretrained` method. + + ```py + from transformers.keras_callbacks import PushToHubCallback + + push_to_hub_callback = PushToHubCallback( + output_dir="./model_save", + tokenizer=tokenizer, + hub_model_id="gpt5-7xlarge", + ) + + model.fit(train_dataset, callbacks=[push_to_hub_callback]) + ``` + + Args: + output_dir (`str`): + The output directory where the model predictions and checkpoints will be written and synced with the + repository on the Hub. + save_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"epoch"`): + The checkpoint save strategy to adopt during training. Possible values are: + + - `"no"`: Save is done at the end of training. + - `"epoch"`: Save is done at the end of each epoch. + - `"steps"`: Save is done every `save_steps` + save_steps (`int`, *optional*): + The number of steps between saves when using the "steps" `save_strategy`. + tokenizer (`PreTrainedTokenizerBase`, *optional*): + The tokenizer used by the model. If supplied, will be uploaded to the repo alongside the weights. + hub_model_id (`str`, *optional*): + The name of the repository to keep in sync with the local `output_dir`. It can be a simple model ID in + which case the model will be pushed in your namespace. Otherwise it should be the whole repository name, + for instance `"user_name/model"`, which allows you to push to an organization you are a member of with + `"organization_name/model"`. + + Will default to the name of `output_dir`. + hub_token (`str`, *optional*): + The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with + `hf auth login`. + checkpoint (`bool`, *optional*, defaults to `False`): + Whether to save full training checkpoints (including epoch and optimizer state) to allow training to be + resumed. Only usable when `save_strategy` is `"epoch"`. + """ + + def __init__( + self, + output_dir: Union[str, Path], + save_strategy: Union[str, IntervalStrategy] = "epoch", + save_steps: Optional[int] = None, + tokenizer: Optional[PreTrainedTokenizerBase] = None, + hub_model_id: Optional[str] = None, + hub_token: Optional[str] = None, + checkpoint: bool = False, + **model_card_args, + ): + super().__init__() + if checkpoint and save_strategy != "epoch": + raise ValueError("Cannot save checkpoints when save_strategy is not 'epoch'!") + if isinstance(save_strategy, str): + save_strategy = IntervalStrategy(save_strategy.lower()) + self.save_strategy = save_strategy + if self.save_strategy == IntervalStrategy.STEPS and (not isinstance(save_steps, int) or save_steps <= 0): + raise ValueError("Please supply a positive integer argument for save_steps when save_strategy == 'steps'!") + self.save_steps = save_steps + output_dir = Path(output_dir) + + # Create repo and retrieve repo_id + if hub_model_id is None: + hub_model_id = output_dir.absolute().name + self.hub_model_id = create_repo(repo_id=hub_model_id, exist_ok=True, token=hub_token).repo_id + + self.output_dir = output_dir + self.repo = Repository(str(self.output_dir), clone_from=self.hub_model_id, token=hub_token) + + self.tokenizer = tokenizer + self.last_job = None + self.checkpoint = checkpoint + self.training_history = None + self.model_card_args = model_card_args + + def on_train_begin(self, logs=None): + # Although we can access model.history, we have no guarantees that the History callback will fire before this + # one, so we keep track of it here too + self.training_history = [] + + def on_train_batch_end(self, batch, logs=None): + if self.save_strategy == IntervalStrategy.STEPS and (batch + 1) % self.save_steps == 0: + if self.last_job is not None and not self.last_job.is_done: + return # The last upload is still running, don't start another + self.model.save_pretrained(self.output_dir) + if self.tokenizer is not None: + self.tokenizer.save_pretrained(self.output_dir) + _, self.last_job = self.repo.push_to_hub( + commit_message=f"Training in progress steps {batch}", blocking=False + ) + + def on_epoch_end(self, epoch, logs=None): + logs = logs.copy() # Don't accidentally write things that Keras will read later + if "epoch" not in logs: + logs["epoch"] = epoch + self.training_history.append(logs) + if self.save_strategy == IntervalStrategy.EPOCH: + if self.last_job is not None and not self.last_job.is_done: + return # The last upload is still running, don't start another + self.model.save_pretrained(self.output_dir) + if self.tokenizer is not None: + self.tokenizer.save_pretrained(self.output_dir) + if self.checkpoint: + checkpoint_dir = os.path.join(self.output_dir, "checkpoint") + self.model._save_checkpoint(checkpoint_dir, epoch) + train_summary = TrainingSummary.from_keras( + model=self.model, + model_name=self.hub_model_id, + keras_history=self.training_history, + **self.model_card_args, + ) + model_card = train_summary.to_model_card() + with (self.output_dir / "README.md").open("w") as f: + f.write(model_card) + _, self.last_job = self.repo.push_to_hub( + commit_message=f"Training in progress epoch {epoch}", blocking=False + ) + + def on_train_end(self, logs=None): + # Makes sure the latest version of the model is uploaded + if self.last_job is not None and not self.last_job.is_done: + logging.info("Pushing the last epoch to the Hub, this may take a while...") + while not self.last_job.is_done: + sleep(1) + else: + self.model.save_pretrained(self.output_dir) + if self.tokenizer is not None: + self.tokenizer.save_pretrained(self.output_dir) + train_summary = TrainingSummary.from_keras( + model=self.model, + model_name=self.hub_model_id, + keras_history=self.training_history, + **self.model_card_args, + ) + model_card = train_summary.to_model_card() + with (self.output_dir / "README.md").open("w") as f: + f.write(model_card) + self.repo.push_to_hub(commit_message="End of training", blocking=True) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/model_debugging_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/model_debugging_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2c7b47c04fd508319e9f13511e5621260108dc2b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/model_debugging_utils.py @@ -0,0 +1,456 @@ +# Copyright 2025 The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import json +import os +import re +from contextlib import contextmanager, redirect_stdout +from io import StringIO +from typing import Optional + +from .utils import logging +from .utils.import_utils import is_torch_available, requires + + +if is_torch_available(): + import torch + from safetensors.torch import save_file + + _torch_distributed_available = False + # Note to code inspectors: this toolbox is intended for people who add models to `transformers`. + if torch.distributed.is_available(): + import torch.distributed.tensor + + _torch_distributed_available = True +else: + _torch_distributed_available = False + + +logger = logging.get_logger(__name__) + + +def _is_rank_zero(): + """Return True if rank=0 or we aren't running distributed.""" + if not (_torch_distributed_available and torch.distributed.is_initialized()): + return True + return torch.distributed.get_rank() == 0 + + +MEMORY_ADDRESS_REGEX = re.compile(r"object at 0x[0-9A-Fa-f]+") + + +def _sanitize_repr_for_diff(x_str: str) -> str: + """ + Replace memory addresses in an object's repr with a stable placeholder + so that beautiful JSON diffs won't be ruined by ephemeral addresses. + """ + return MEMORY_ADDRESS_REGEX.sub("object at 0xXXXXXXXX", x_str) + + +def _dtensor_repr(x): + """Return a stable string representation for a DTensor-like object.""" + if _is_rank_zero(): + return f"DTensor (rank0) -> {repr(x._local_tensor)}" + return "DTensor(non-rank0)" + + +def _serialize_tensor_like_io( + value, debug_path: Optional[str] = None, use_repr: bool = True, path_to_value: Optional[str] = None +): + """ + Converts Tensors and DTensors to a JSON-serializable dictionary representation. + + Args: + value: Any Python object, often including torch Tensors, lists, dicts, etc. + debug_path (`str`, *optional*, defaults to `None`): Directory to dump debug JSON and SafeTensors files. + use_repr (bool, *optional*, defaults to `True`): Whether to save a `repr()`-ized version of the tensor as the + `value` property in the asscoiated FULL_TENSORS.json file, or to store the full tensors in separate + SafeTensors file and store the relative path to that file in the `value` property in the dictionary. + path_to_value (`str`, *optional*, defaults to `None`): The file name for the SafeTensors file holding the full + tensor value if `use_repr=False`. + + Returns: + A nested Python structure (list, dict, or sanitized string) that is safe to json.dump. + """ + torch.set_printoptions(sci_mode=True) + + if use_repr: + value_out = _repr_to_list(value) + elif path_to_value: + if not path_to_value.endswith(".safetensors"): + path_to_value += ".safetensors" + + filepath = os.path.join(debug_path, path_to_value) if debug_path else path_to_value + save_file({"data": value.contiguous().detach().cpu()}, filepath) + value_out = f"./{path_to_value}" + else: + raise ValueError(f"{use_repr=} and {path_to_value=} cannot both be falsy.") + + out = { + "shape": repr(value.shape), + "dtype": repr(value.dtype), + "value": value_out, + } + if value.dtype in {torch.float16, torch.float32, torch.bfloat16}: + out.update( + { + "mean": _sanitize_repr_for_diff(repr(value.mean())), + "std": _sanitize_repr_for_diff(repr(value.std())), + "min": _sanitize_repr_for_diff(repr(value.min())), + "max": _sanitize_repr_for_diff(repr(value.max())), + } + ) + return out + + +def _serialize_io(value, debug_path: Optional[str] = None, use_repr: bool = True, path_to_value: Optional[str] = None): + """ + Recursively build a JSON-serializable Python structure from `value`. + Tensors and DTensors become either sanitized repr strings, or are saved to disk as SafeTensors files and their + relative paths are recorded in the returned Python structure. + Lists/tuples/dicts are recursed into. + All memory addresses are replaced with a stable placeholder. + + Args: + value: Any Python object, often including torch Tensors, lists, dicts, etc. + debug_path (`str`, *optional*, defaults to `None`): Directory to dump debug JSON and SafeTensors files. + use_repr (bool, *optional*, defaults to `True`): Whether to save a `repr()`-ized version of the tensors as the + `value` property in the asscoiated FULL_TENSORS.json file, or to store full tensors in separate SafeTensors + files and store the relative path to that file in the `value` property. + path_to_value (`str`, *optional*, defaults to `None`): The file name for the SafeTensors file holding the full + tensor value if `use_repr=False`. + + Returns: + A nested Python structure (list, dict, or sanitized string) that is safe to json.dump. + """ + if isinstance(value, (list, tuple)): + return [ + _serialize_io(v, debug_path=debug_path, use_repr=use_repr, path_to_value=f"{path_to_value}_{i}") + for i, v in enumerate(value) + ] + + if isinstance(value, dict): + return { + k: _serialize_io(v, debug_path=debug_path, use_repr=use_repr, path_to_value=f"{path_to_value}_{k}") + for k, v in value.items() + } + + if hasattr(value, "_local_tensor"): + return _serialize_tensor_like_io( + value._local_tensor, debug_path=debug_path, use_repr=use_repr, path_to_value=path_to_value + ) + + if isinstance(value, torch.Tensor): + return _serialize_tensor_like_io(value, debug_path=debug_path, use_repr=use_repr, path_to_value=path_to_value) + + return _sanitize_repr_for_diff(repr(value)) + + +def _repr_to_list(value: torch.Tensor): + """ + Converts a tensor into a sanitized multi-line string representation. + + Args: + value (`torch.Tensor`): The tensor to represent. + + Returns: + `list[str]`: List of string lines representing the tensor. + """ + torch.set_printoptions(sci_mode=True, linewidth=120) + with StringIO() as buf, redirect_stdout(buf): + print(value) # to redirected stdout to avoid line splits + raw = buf.getvalue() + return _sanitize_repr_for_diff(raw).splitlines() + + +def prune_outputs_if_children(node): + # if there are children, remove this node's "outputs" + # so we only see outputs at the leaf level + if node.get("children"): + node.pop("outputs", None) + for child in node["children"]: + prune_outputs_if_children(child) + + +LAYER_SUFFIX_RE = re.compile(r"(.*)\.(\d+)$") # should be generic enough, ends with a number + + +def is_layer_block(node): + """ + Checks whether a node represents a layer block with submodules. + + Args: + node (`dict`): A node from the call tree. + + Returns: + `bool`: Whether the node is a layer block. + """ + match = LAYER_SUFFIX_RE.match(node.get("module_path", "")) + if not match or not node.get("children"): + return False + number = match.group(2) + return any(f".{number}." in child.get("module_path", "") for child in node["children"]) + + +def prune_intermediate_layers(node): + """ + Recursively removes intermediate layers from the tree to improve readability. + Keeps at least the first and last layers if many consecutive layers are present. + + Args: + node (`dict`): The root or subnode to prune recursively. + """ + if not node.get("children"): + return + layer_blocks = [(i, child) for i, child in enumerate(node["children"]) if is_layer_block(child)] + + if len(layer_blocks) > 2: + to_remove = [i for i, _ in layer_blocks[1:-1]] + node["children"] = [child for i, child in enumerate(node["children"]) if i not in to_remove] + + for child in node["children"]: + prune_intermediate_layers(child) + + +def log_model_debug_trace(debug_path: Optional[str], model): + if debug_path: + try: + os.makedirs(debug_path, exist_ok=True) + base = os.path.join(debug_path, model._debugger_module_dump_name + "_debug_tree") + except Exception as e: + raise ValueError(f"Unexpected or existing debug_path={debug_path}.") from e + else: + base = model._debugger_module_dump_name + "_debug_tree" + + logger.info(f"Writing model trace at {base}.json") + full_path = base + "_FULL_TENSORS.json" + summary_path = base + "_SUMMARY.json" + + prune_outputs_if_children(model._call_tree) + + with open(full_path, "w") as f: + json.dump(model._call_tree, f, indent=2) + + # summary-only version for readability - traversing the tree again #TODO optimize? + def strip_values(node): + def clean(val): + if isinstance(val, dict): + val.pop("value", None) + for v in val.values(): + clean(v) + elif isinstance(val, list): + for item in val: + clean(item) + + clean(node.get("inputs", {})) + clean(node.get("outputs", {})) + + for child in node.get("children", []): + strip_values(child) + + tree_copy = json.loads(json.dumps(model._call_tree)) # deep copy + strip_values(tree_copy) + + with open(summary_path, "w") as f: + json.dump(tree_copy, f, indent=2) + + +def _attach_debugger_logic( + model, + debug_path: str = ".", + do_prune_layers: bool = True, + use_repr: bool = True, +): + """ + Attaches a debugging wrapper to every module in the model. + + This records structured inputs and outputs during the forward pass into a call tree. + + Args: + model (`PreTrainedModel`, `nn.Module`): Model to wrap. + debug_path (`str`): Optional directory to dump debug JSON files. + do_prune_layers (`bool`, *optional*, defaults to `True`): Whether to prune intermediate layers. + use_repr (bool, *optional*, defaults to `True`): Whether to save a `repr()`-ized version of the tensors as the + `value` property in the associated FULL_TENSORS.json file, or to store full tensors in separate SafeTensors + files and store the relative path to that file in the `value` property. + """ + class_name = model.__class__.__name__ + + # Prepare data structures on the model object + model._call_tree = {"module_path": class_name, "inputs": None, "outputs": None, "children": []} + model._debugger_model_call_stack = [] + model._debugger_module_dump_name = class_name # used for final JSON filename + + if debug_path: + try: + os.makedirs(debug_path, exist_ok=True) + except Exception as e: + raise ValueError(f"Unexpected or existing debug_path={debug_path}.") from e + + def wrap_forward(module, full_path): + orig_forward = module.forward + + @functools.wraps(orig_forward) + def wrapped_forward(*inps, **kws): + if _is_rank_zero(): + dict_inputs = {"args": inps, "kwargs": kws} + dict_inputs = {k: dict_inputs[k] for k in dict_inputs if len(dict_inputs[k]) > 0} + node = { + "module_path": full_path, + "inputs": _serialize_io( + dict_inputs, + debug_path=debug_path, + use_repr=use_repr, + path_to_value=f"{full_path}_inputs", + ), + "outputs": None, + "children": [], + } + model._debugger_model_call_stack.append(node) + with torch.no_grad(): + out = orig_forward(*inps, **kws) + + if _is_rank_zero(): + if sum(1 for _ in module.named_children()) > 0: + node["outputs"] = None + else: + node["outputs"] = _serialize_io( + out, + debug_path=debug_path, + use_repr=use_repr, + path_to_value=f"{full_path}_outputs", + ) + + finished = model._debugger_model_call_stack.pop() + # prune empty vertices here as well (mostly empty children nodes) + if not finished["children"]: + finished.pop("children") + + if model._debugger_model_call_stack: + model._debugger_model_call_stack[-1]["children"].append(finished) + return out + + module.forward = wrapped_forward + + # wrap all submodules + for name, submodule in model.named_modules(): + if name == "": + continue + wrap_forward(submodule, f"{class_name}.{name}") + + # wrap top-level forward + real_top_forward = model.forward + + @functools.wraps(real_top_forward) + def top_wrapped_forward(*inps, **kws): + if _is_rank_zero(): + top_node = { + "module_path": f"{class_name} (top-level)", + "inputs": _serialize_io( + {"args": inps, "kwargs": kws}, + debug_path=debug_path, + use_repr=use_repr, + path_to_value=f"{class_name}_inputs", + ), + "outputs": None, + "children": [], + } + model._debugger_model_call_stack.append(top_node) + + out = real_top_forward(*inps, **kws) + if _is_rank_zero() and model._debugger_model_call_stack: + top_node["outputs"] = _serialize_io( + out, + debug_path=debug_path, + use_repr=use_repr, + path_to_value=f"{class_name}_outputs", + ) + finished = model._debugger_model_call_stack.pop() + model._call_tree["inputs"] = finished["inputs"] + model._call_tree["outputs"] = finished["outputs"] + model._call_tree["children"] = finished["children"] + # prune empty stuff for visibility + [model._call_tree.pop(k, None) for k in list(model._call_tree.keys()) if not model._call_tree[k]] + + # prune layers that are not 0 or last + if do_prune_layers: + prune_intermediate_layers(model._call_tree) + # Write final JSON trace here + log_model_debug_trace(debug_path=debug_path, model=model) + return out + + model.forward = top_wrapped_forward + + +@requires(backends=("torch",)) +@contextmanager +def model_addition_debugger_context( + model, + debug_path: Optional[str] = None, + do_prune_layers: bool = True, + use_repr: bool = True, +): + """ + # Model addition debugger - context manager for model adders + This context manager is a power user tool intended for model adders. + + It tracks all forward calls within a model forward and logs a slice of each input and output on a nested JSON file. + If `use_repr=True` (the default), the JSON file will record a `repr()`-ized version of the tensors as a list of + strings. If `use_repr=False`, the full tensors will be stored in separate SafeTensors files and the JSON file will + provide a relative path to that file. + + To note, this context manager enforces `torch.no_grad()`. + + ## Usage + + add the context manager to a model to debug + + ```python + import torch + + from PIL import Image + from transformers import LlavaProcessor, LlavaForConditionalGeneration, model_addition_debugger_context + + torch.random.manual_seed(673) + + # load pretrained model and processor + model_id = "llava-hf/llava-1.5-7b-hf" + processor = LlavaProcessor.from_pretrained(model_id) + model = LlavaForConditionalGeneration.from_pretrained(model_id) + + # create random image input + random_image = Image.fromarray(torch.randint(0, 256, (224, 224, 3), dtype=torch.uint8).numpy()) + + # prompt + prompt = "Describe this image." + + # process inputs + inputs = processor(text=prompt, images=random_image, return_tensors="pt") + + # call forward method (not .generate!) + with model_addition_debugger_context(model, debug_path="Your_debug_path", do_prune_layers=False): + output = model.forward(**inputs) + ``` + + """ + orig_forwards = {m: m.forward for _, m in model.named_modules()} + orig_forwards[model] = model.forward + _attach_debugger_logic(model, debug_path, do_prune_layers, use_repr) + try: + yield model + finally: + for module_instance, forward_method in orig_forwards.items(): + module_instance.forward = forward_method diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modelcard.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modelcard.py new file mode 100644 index 0000000000000000000000000000000000000000..dd3a0b4017334752ecec35a634ead9eae96d462b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modelcard.py @@ -0,0 +1,914 @@ +# Copyright 2018 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Configuration base class and utilities.""" + +import copy +import json +import os +import warnings +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Optional, Union + +import requests +import yaml +from huggingface_hub import model_info +from huggingface_hub.errors import OfflineModeIsEnabled +from huggingface_hub.utils import HFValidationError + +from . import __version__ +from .models.auto.modeling_auto import ( + MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES, + MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, + MODEL_FOR_CTC_MAPPING_NAMES, + MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES, + MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES, + MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES, + MODEL_FOR_MASKED_LM_MAPPING_NAMES, + MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES, + MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, + MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, + MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES, + MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES, + MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES, + MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, + MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES, +) +from .training_args import ParallelMode +from .utils import ( + MODEL_CARD_NAME, + cached_file, + is_datasets_available, + is_offline_mode, + is_tf_available, + is_tokenizers_available, + is_torch_available, + logging, +) + + +TASK_MAPPING = { + "text-generation": MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, + "image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES, + "image-segmentation": MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES, + "fill-mask": MODEL_FOR_MASKED_LM_MAPPING_NAMES, + "object-detection": MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES, + "question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, + "text2text-generation": MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, + "text-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES, + "table-question-answering": MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES, + "token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, + "audio-classification": MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES, + "automatic-speech-recognition": {**MODEL_FOR_CTC_MAPPING_NAMES, **MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES}, + "zero-shot-image-classification": MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES, + "image-text-to-text": MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES, +} + +logger = logging.get_logger(__name__) + + +class ModelCard: + r""" + Structured Model Card class. Store model card as well as methods for loading/downloading/saving model cards. + + Please read the following paper for details and explanation on the sections: "Model Cards for Model Reporting" by + Margaret Mitchell, Simone Wu, Andrew Zaldivar, Parker Barnes, Lucy Vasserman, Ben Hutchinson, Elena Spitzer, + Inioluwa Deborah Raji and Timnit Gebru for the proposal behind model cards. Link: https://huggingface.co/papers/1810.03993 + + Note: A model card can be loaded and saved to disk. + """ + + def __init__(self, **kwargs): + warnings.warn( + "The class `ModelCard` is deprecated and will be removed in version 5 of Transformers", FutureWarning + ) + # Recommended attributes from https://huggingface.co/papers/1810.03993 (see papers) + self.model_details = kwargs.pop("model_details", {}) + self.intended_use = kwargs.pop("intended_use", {}) + self.factors = kwargs.pop("factors", {}) + self.metrics = kwargs.pop("metrics", {}) + self.evaluation_data = kwargs.pop("evaluation_data", {}) + self.training_data = kwargs.pop("training_data", {}) + self.quantitative_analyses = kwargs.pop("quantitative_analyses", {}) + self.ethical_considerations = kwargs.pop("ethical_considerations", {}) + self.caveats_and_recommendations = kwargs.pop("caveats_and_recommendations", {}) + + # Open additional attributes + for key, value in kwargs.items(): + try: + setattr(self, key, value) + except AttributeError as err: + logger.error(f"Can't set {key} with value {value} for {self}") + raise err + + def save_pretrained(self, save_directory_or_file): + """Save a model card object to the directory or file `save_directory_or_file`.""" + if os.path.isdir(save_directory_or_file): + # If we save using the predefined names, we can load using `from_pretrained` + output_model_card_file = os.path.join(save_directory_or_file, MODEL_CARD_NAME) + else: + output_model_card_file = save_directory_or_file + + self.to_json_file(output_model_card_file) + logger.info(f"Model card saved in {output_model_card_file}") + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + r""" + Instantiate a [`ModelCard`] from a pre-trained model model card. + + Parameters: + pretrained_model_name_or_path: either: + + - a string, the *model id* of a pretrained model card hosted inside a model repo on huggingface.co. + - a path to a *directory* containing a model card file saved using the [`~ModelCard.save_pretrained`] + method, e.g.: `./my_model_directory/`. + - a path or url to a saved model card JSON *file*, e.g.: `./my_model_directory/modelcard.json`. + + cache_dir: (*optional*) string: + Path to a directory in which a downloaded pre-trained model card should be cached if the standard cache + should not be used. + + kwargs: (*optional*) dict: key/value pairs with which to update the ModelCard object after loading. + + - The values in kwargs of any keys which are model card attributes will be used to override the loaded + values. + - Behavior concerning key/value pairs whose keys are *not* model card attributes is controlled by the + *return_unused_kwargs* keyword parameter. + + proxies: (*optional*) dict, default None: + A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request. + + return_unused_kwargs: (*optional*) bool: + + - If False, then this function returns just the final model card object. + - If True, then this functions returns a tuple *(model card, unused_kwargs)* where *unused_kwargs* is a + dictionary consisting of the key/value pairs whose keys are not model card attributes: ie the part of + kwargs which has not been used to update *ModelCard* and is otherwise ignored. + + Examples: + + ```python + # Download model card from huggingface.co and cache. + modelcard = ModelCard.from_pretrained("google-bert/bert-base-uncased") + # Model card was saved using *save_pretrained('./test/saved_model/')* + modelcard = ModelCard.from_pretrained("./test/saved_model/") + modelcard = ModelCard.from_pretrained("./test/saved_model/modelcard.json") + modelcard = ModelCard.from_pretrained("google-bert/bert-base-uncased", output_attentions=True, foo=False) + ```""" + cache_dir = kwargs.pop("cache_dir", None) + proxies = kwargs.pop("proxies", None) + return_unused_kwargs = kwargs.pop("return_unused_kwargs", False) + from_pipeline = kwargs.pop("_from_pipeline", None) + + user_agent = {"file_type": "model_card"} + if from_pipeline is not None: + user_agent["using_pipeline"] = from_pipeline + + is_local = os.path.isdir(pretrained_model_name_or_path) + if os.path.isfile(pretrained_model_name_or_path): + resolved_model_card_file = pretrained_model_name_or_path + is_local = True + else: + try: + # Load from URL or cache if already cached + resolved_model_card_file = cached_file( + pretrained_model_name_or_path, + filename=MODEL_CARD_NAME, + cache_dir=cache_dir, + proxies=proxies, + user_agent=user_agent, + ) + if is_local: + logger.info(f"loading model card file {resolved_model_card_file}") + else: + logger.info(f"loading model card file {MODEL_CARD_NAME} from cache at {resolved_model_card_file}") + # Load model card + modelcard = cls.from_json_file(resolved_model_card_file) + + except (OSError, json.JSONDecodeError): + # We fall back on creating an empty model card + modelcard = cls() + + # Update model card with kwargs if needed + to_remove = [] + for key, value in kwargs.items(): + if hasattr(modelcard, key): + setattr(modelcard, key, value) + to_remove.append(key) + for key in to_remove: + kwargs.pop(key, None) + + logger.info(f"Model card: {modelcard}") + if return_unused_kwargs: + return modelcard, kwargs + else: + return modelcard + + @classmethod + def from_dict(cls, json_object): + """Constructs a `ModelCard` from a Python dictionary of parameters.""" + return cls(**json_object) + + @classmethod + def from_json_file(cls, json_file): + """Constructs a `ModelCard` from a json file of parameters.""" + with open(json_file, encoding="utf-8") as reader: + text = reader.read() + dict_obj = json.loads(text) + return cls(**dict_obj) + + def __eq__(self, other): + return self.__dict__ == other.__dict__ + + def __repr__(self): + return str(self.to_json_string()) + + def to_dict(self): + """Serializes this instance to a Python dictionary.""" + output = copy.deepcopy(self.__dict__) + return output + + def to_json_string(self): + """Serializes this instance to a JSON string.""" + return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" + + def to_json_file(self, json_file_path): + """Save this instance to a json file.""" + with open(json_file_path, "w", encoding="utf-8") as writer: + writer.write(self.to_json_string()) + + +AUTOGENERATED_TRAINER_COMMENT = """ + +""" + +AUTOGENERATED_KERAS_COMMENT = """ + +""" + + +TASK_TAG_TO_NAME_MAPPING = { + "fill-mask": "Masked Language Modeling", + "image-classification": "Image Classification", + "image-segmentation": "Image Segmentation", + "multiple-choice": "Multiple Choice", + "object-detection": "Object Detection", + "question-answering": "Question Answering", + "summarization": "Summarization", + "table-question-answering": "Table Question Answering", + "text-classification": "Text Classification", + "text-generation": "Causal Language Modeling", + "text2text-generation": "Sequence-to-sequence Language Modeling", + "token-classification": "Token Classification", + "translation": "Translation", + "zero-shot-classification": "Zero Shot Classification", + "automatic-speech-recognition": "Automatic Speech Recognition", + "audio-classification": "Audio Classification", +} + + +METRIC_TAGS = [ + "accuracy", + "bleu", + "f1", + "matthews_correlation", + "pearsonr", + "precision", + "recall", + "rouge", + "sacrebleu", + "spearmanr", + "wer", +] + + +def _listify(obj): + if obj is None: + return [] + elif isinstance(obj, str): + return [obj] + else: + return obj + + +def _insert_values_as_list(metadata, name, values): + if values is None: + return metadata + if isinstance(values, str): + values = [values] + values = [v for v in values if v is not None] + if len(values) == 0: + return metadata + metadata[name] = values + return metadata + + +def infer_metric_tags_from_eval_results(eval_results): + if eval_results is None: + return {} + result = {} + for key in eval_results: + if key.lower().replace(" ", "_") in METRIC_TAGS: + result[key.lower().replace(" ", "_")] = key + elif key.lower() == "rouge1": + result["rouge"] = key + return result + + +def _insert_value(metadata, name, value): + if value is None: + return metadata + metadata[name] = value + return metadata + + +def is_hf_dataset(dataset): + if not is_datasets_available(): + return False + + from datasets import Dataset, IterableDataset + + return isinstance(dataset, (Dataset, IterableDataset)) + + +def _get_mapping_values(mapping): + result = [] + for v in mapping.values(): + if isinstance(v, (tuple, list)): + result += list(v) + else: + result.append(v) + return result + + +@dataclass +class TrainingSummary: + model_name: str + language: Optional[Union[str, list[str]]] = None + license: Optional[str] = None + tags: Optional[Union[str, list[str]]] = None + finetuned_from: Optional[str] = None + tasks: Optional[Union[str, list[str]]] = None + dataset: Optional[Union[str, list[str]]] = None + dataset_tags: Optional[Union[str, list[str]]] = None + dataset_args: Optional[Union[str, list[str]]] = None + dataset_metadata: Optional[dict[str, Any]] = None + eval_results: Optional[dict[str, float]] = None + eval_lines: Optional[list[str]] = None + hyperparameters: Optional[dict[str, Any]] = None + source: Optional[str] = "trainer" + + def __post_init__(self): + # Infer default license from the checkpoint used, if possible. + if ( + self.license is None + and not is_offline_mode() + and self.finetuned_from is not None + and len(self.finetuned_from) > 0 + ): + try: + info = model_info(self.finetuned_from) + for tag in info.tags: + if tag.startswith("license:"): + self.license = tag[8:] + except ( + requests.exceptions.HTTPError, + requests.exceptions.ConnectionError, + HFValidationError, + OfflineModeIsEnabled, + ): + pass + + def create_model_index(self, metric_mapping): + model_index = {"name": self.model_name} + + # Dataset mapping tag -> name + dataset_names = _listify(self.dataset) + dataset_tags = _listify(self.dataset_tags) + dataset_args = _listify(self.dataset_args) + dataset_metadata = _listify(self.dataset_metadata) + if len(dataset_args) < len(dataset_tags): + dataset_args = dataset_args + [None] * (len(dataset_tags) - len(dataset_args)) + dataset_mapping = dict(zip(dataset_tags, dataset_names)) + dataset_arg_mapping = dict(zip(dataset_tags, dataset_args)) + dataset_metadata_mapping = dict(zip(dataset_tags, dataset_metadata)) + + task_mapping = { + task: TASK_TAG_TO_NAME_MAPPING[task] for task in _listify(self.tasks) if task in TASK_TAG_TO_NAME_MAPPING + } + + model_index["results"] = [] + + if len(task_mapping) == 0 and len(dataset_mapping) == 0: + return [model_index] + if len(task_mapping) == 0: + task_mapping = {None: None} + if len(dataset_mapping) == 0: + dataset_mapping = {None: None} + + # One entry per dataset and per task + all_possibilities = [(task_tag, ds_tag) for task_tag in task_mapping for ds_tag in dataset_mapping] + for task_tag, ds_tag in all_possibilities: + result = {} + if task_tag is not None: + result["task"] = {"name": task_mapping[task_tag], "type": task_tag} + + if ds_tag is not None: + metadata = dataset_metadata_mapping.get(ds_tag, {}) + result["dataset"] = { + "name": dataset_mapping[ds_tag], + "type": ds_tag, + **metadata, + } + if dataset_arg_mapping[ds_tag] is not None: + result["dataset"]["args"] = dataset_arg_mapping[ds_tag] + + if len(metric_mapping) > 0: + result["metrics"] = [] + for metric_tag, metric_name in metric_mapping.items(): + result["metrics"].append( + { + "name": metric_name, + "type": metric_tag, + "value": self.eval_results[metric_name], + } + ) + + # Remove partial results to avoid the model card being rejected. + if "task" in result and "dataset" in result and "metrics" in result: + model_index["results"].append(result) + else: + logger.info(f"Dropping the following result as it does not have all the necessary fields:\n{result}") + + return [model_index] + + def create_metadata(self): + metric_mapping = infer_metric_tags_from_eval_results(self.eval_results) + + metadata = {} + metadata = _insert_value(metadata, "library_name", "transformers") + metadata = _insert_values_as_list(metadata, "language", self.language) + metadata = _insert_value(metadata, "license", self.license) + if self.finetuned_from is not None and isinstance(self.finetuned_from, str) and len(self.finetuned_from) > 0: + metadata = _insert_value(metadata, "base_model", self.finetuned_from) + metadata = _insert_values_as_list(metadata, "tags", self.tags) + metadata = _insert_values_as_list(metadata, "datasets", self.dataset_tags) + metadata = _insert_values_as_list(metadata, "metrics", list(metric_mapping.keys())) + metadata["model-index"] = self.create_model_index(metric_mapping) + + return metadata + + def to_model_card(self): + model_card = "" + + metadata = yaml.dump(self.create_metadata(), sort_keys=False) + if len(metadata) > 0: + model_card = f"---\n{metadata}---\n" + + # Now the model card for realsies. + if self.source == "trainer": + model_card += AUTOGENERATED_TRAINER_COMMENT + else: + model_card += AUTOGENERATED_KERAS_COMMENT + + model_card += f"\n# {self.model_name}\n\n" + + if self.finetuned_from is None: + model_card += "This model was trained from scratch on " + else: + model_card += ( + "This model is a fine-tuned version of" + f" [{self.finetuned_from}](https://huggingface.co/{self.finetuned_from}) on " + ) + + if self.dataset is None or (isinstance(self.dataset, list) and len(self.dataset) == 0): + model_card += "an unknown dataset." + else: + if isinstance(self.dataset, str): + model_card += f"the {self.dataset} dataset." + elif isinstance(self.dataset, (tuple, list)) and len(self.dataset) == 1: + model_card += f"the {self.dataset[0]} dataset." + else: + model_card += ( + ", ".join([f"the {ds}" for ds in self.dataset[:-1]]) + f" and the {self.dataset[-1]} datasets." + ) + + if self.eval_results is not None: + model_card += "\nIt achieves the following results on the evaluation set:\n" + model_card += "\n".join([f"- {name}: {_maybe_round(value)}" for name, value in self.eval_results.items()]) + model_card += "\n" + + model_card += "\n## Model description\n\nMore information needed\n" + model_card += "\n## Intended uses & limitations\n\nMore information needed\n" + model_card += "\n## Training and evaluation data\n\nMore information needed\n" + + model_card += "\n## Training procedure\n" + model_card += "\n### Training hyperparameters\n" + if self.hyperparameters is not None: + model_card += "\nThe following hyperparameters were used during training:\n" + model_card += "\n".join([f"- {name}: {value}" for name, value in self.hyperparameters.items()]) + model_card += "\n" + else: + model_card += "\nMore information needed\n" + + if self.eval_lines is not None: + model_card += "\n### Training results\n\n" + model_card += make_markdown_table(self.eval_lines) + model_card += "\n" + + model_card += "\n### Framework versions\n\n" + model_card += f"- Transformers {__version__}\n" + + if self.source == "trainer" and is_torch_available(): + import torch + + model_card += f"- Pytorch {torch.__version__}\n" + elif self.source == "keras" and is_tf_available(): + import tensorflow as tf + + model_card += f"- TensorFlow {tf.__version__}\n" + if is_datasets_available(): + import datasets + + model_card += f"- Datasets {datasets.__version__}\n" + if is_tokenizers_available(): + import tokenizers + + model_card += f"- Tokenizers {tokenizers.__version__}\n" + + return model_card + + @classmethod + def from_trainer( + cls, + trainer, + language=None, + license=None, + tags=None, + model_name=None, + finetuned_from=None, + tasks=None, + dataset_tags=None, + dataset_metadata=None, + dataset=None, + dataset_args=None, + ): + # Infer default from dataset + one_dataset = trainer.eval_dataset if trainer.eval_dataset is not None else trainer.train_dataset + if is_hf_dataset(one_dataset) and (dataset_tags is None or dataset_args is None or dataset_metadata is None): + default_tag = one_dataset.builder_name + # Those are not real datasets from the Hub so we exclude them. + if default_tag not in ["csv", "json", "pandas", "parquet", "text"]: + if dataset_metadata is None: + dataset_metadata = [{"config": one_dataset.config_name, "split": str(one_dataset.split)}] + if dataset_tags is None: + dataset_tags = [default_tag] + if dataset_args is None: + dataset_args = [one_dataset.config_name] + + if dataset is None and dataset_tags is not None: + dataset = dataset_tags + + # Infer default finetuned_from + if ( + finetuned_from is None + and hasattr(trainer.model.config, "_name_or_path") + and not os.path.isdir(trainer.model.config._name_or_path) + ): + finetuned_from = trainer.model.config._name_or_path + + # Infer default task tag: + if tasks is None: + model_class_name = trainer.model.__class__.__name__ + for task, mapping in TASK_MAPPING.items(): + if model_class_name in _get_mapping_values(mapping): + tasks = task + + if model_name is None: + model_name = Path(trainer.args.output_dir).name + if len(model_name) == 0: + model_name = finetuned_from + + # Add `generated_from_trainer` to the tags + if tags is None: + tags = ["generated_from_trainer"] + elif isinstance(tags, str) and tags != "generated_from_trainer": + tags = [tags, "generated_from_trainer"] + elif "generated_from_trainer" not in tags: + tags.append("generated_from_trainer") + + _, eval_lines, eval_results = parse_log_history(trainer.state.log_history) + hyperparameters = extract_hyperparameters_from_trainer(trainer) + + return cls( + language=language, + license=license, + tags=tags, + model_name=model_name, + finetuned_from=finetuned_from, + tasks=tasks, + dataset=dataset, + dataset_tags=dataset_tags, + dataset_args=dataset_args, + dataset_metadata=dataset_metadata, + eval_results=eval_results, + eval_lines=eval_lines, + hyperparameters=hyperparameters, + ) + + @classmethod + def from_keras( + cls, + model, + model_name, + keras_history=None, + language=None, + license=None, + tags=None, + finetuned_from=None, + tasks=None, + dataset_tags=None, + dataset=None, + dataset_args=None, + ): + # Infer default from dataset + if dataset is not None: + if is_hf_dataset(dataset) and (dataset_tags is None or dataset_args is None): + default_tag = dataset.builder_name + # Those are not real datasets from the Hub so we exclude them. + if default_tag not in ["csv", "json", "pandas", "parquet", "text"]: + if dataset_tags is None: + dataset_tags = [default_tag] + if dataset_args is None: + dataset_args = [dataset.config_name] + + if dataset is None and dataset_tags is not None: + dataset = dataset_tags + + # Infer default finetuned_from + if ( + finetuned_from is None + and hasattr(model.config, "_name_or_path") + and not os.path.isdir(model.config._name_or_path) + ): + finetuned_from = model.config._name_or_path + + # Infer default task tag: + if tasks is None: + model_class_name = model.__class__.__name__ + for task, mapping in TASK_MAPPING.items(): + if model_class_name in _get_mapping_values(mapping): + tasks = task + + # Add `generated_from_keras_callback` to the tags + if tags is None: + tags = ["generated_from_keras_callback"] + elif isinstance(tags, str) and tags != "generated_from_keras_callback": + tags = [tags, "generated_from_keras_callback"] + elif "generated_from_keras_callback" not in tags: + tags.append("generated_from_keras_callback") + + if keras_history is not None: + _, eval_lines, eval_results = parse_keras_history(keras_history) + else: + eval_lines = [] + eval_results = {} + hyperparameters = extract_hyperparameters_from_keras(model) + + return cls( + language=language, + license=license, + tags=tags, + model_name=model_name, + finetuned_from=finetuned_from, + tasks=tasks, + dataset_tags=dataset_tags, + dataset=dataset, + dataset_args=dataset_args, + eval_results=eval_results, + eval_lines=eval_lines, + hyperparameters=hyperparameters, + source="keras", + ) + + +def parse_keras_history(logs): + """ + Parse the `logs` of either a `keras.History` object returned by `model.fit()` or an accumulated logs `dict` + passed to the `PushToHubCallback`. Returns lines and logs compatible with those returned by `parse_log_history`. + """ + if hasattr(logs, "history"): + # This looks like a `History` object + if not hasattr(logs, "epoch"): + # This history looks empty, return empty results + return None, [], {} + logs.history["epoch"] = logs.epoch + logs = logs.history + else: + # Training logs is a list of dicts, let's invert it to a dict of lists to match a History object + logs = {log_key: [single_dict[log_key] for single_dict in logs] for log_key in logs[0]} + + lines = [] + for i in range(len(logs["epoch"])): + epoch_dict = {log_key: log_value_list[i] for log_key, log_value_list in logs.items()} + values = {} + for k, v in epoch_dict.items(): + if k.startswith("val_"): + k = "validation_" + k[4:] + elif k != "epoch": + k = "train_" + k + splits = k.split("_") + name = " ".join([part.capitalize() for part in splits]) + values[name] = v + lines.append(values) + + eval_results = lines[-1] + + return logs, lines, eval_results + + +def parse_log_history(log_history): + """ + Parse the `log_history` of a Trainer to get the intermediate and final evaluation results. + """ + idx = 0 + while idx < len(log_history) and "train_runtime" not in log_history[idx]: + idx += 1 + + # If there are no training logs + if idx == len(log_history): + idx -= 1 + while idx >= 0 and "eval_loss" not in log_history[idx]: + idx -= 1 + + if idx >= 0: + return None, None, log_history[idx] + else: + return None, None, None + + # From now one we can assume we have training logs: + train_log = log_history[idx] + lines = [] + training_loss = "No log" + for i in range(idx): + if "loss" in log_history[i]: + training_loss = log_history[i]["loss"] + if "eval_loss" in log_history[i]: + metrics = log_history[i].copy() + _ = metrics.pop("total_flos", None) + epoch = metrics.pop("epoch", None) + step = metrics.pop("step", None) + _ = metrics.pop("eval_runtime", None) + _ = metrics.pop("eval_samples_per_second", None) + _ = metrics.pop("eval_steps_per_second", None) + _ = metrics.pop("eval_jit_compilation_time", None) + values = {"Training Loss": training_loss, "Epoch": epoch, "Step": step} + for k, v in metrics.items(): + if k == "eval_loss": + values["Validation Loss"] = v + else: + splits = k.split("_") + name = " ".join([part.capitalize() for part in splits[1:]]) + values[name] = v + lines.append(values) + + idx = len(log_history) - 1 + while idx >= 0 and "eval_loss" not in log_history[idx]: + idx -= 1 + + if idx > 0: + eval_results = {} + for key, value in log_history[idx].items(): + key = key.removeprefix("eval_") + if key not in ["runtime", "samples_per_second", "steps_per_second", "epoch", "step"]: + camel_cased_key = " ".join([part.capitalize() for part in key.split("_")]) + eval_results[camel_cased_key] = value + return train_log, lines, eval_results + else: + return train_log, lines, None + + +def extract_hyperparameters_from_keras(model): + from .modeling_tf_utils import keras + + hyperparameters = {} + if hasattr(model, "optimizer") and model.optimizer is not None: + hyperparameters["optimizer"] = model.optimizer.get_config() + else: + hyperparameters["optimizer"] = None + hyperparameters["training_precision"] = keras.mixed_precision.global_policy().name + + return hyperparameters + + +def _maybe_round(v, decimals=4): + if isinstance(v, float) and len(str(v).split(".")) > 1 and len(str(v).split(".")[1]) > decimals: + return f"{v:.{decimals}f}" + return str(v) + + +def _regular_table_line(values, col_widths): + values_with_space = [f"| {v}" + " " * (w - len(v) + 1) for v, w in zip(values, col_widths)] + return "".join(values_with_space) + "|\n" + + +def _second_table_line(col_widths): + values = ["|:" + "-" * w + ":" for w in col_widths] + return "".join(values) + "|\n" + + +def make_markdown_table(lines): + """ + Create a nice Markdown table from the results in `lines`. + """ + if lines is None or len(lines) == 0: + return "" + col_widths = {key: len(str(key)) for key in lines[0]} + for line in lines: + for key, value in line.items(): + if col_widths[key] < len(_maybe_round(value)): + col_widths[key] = len(_maybe_round(value)) + + table = _regular_table_line(list(lines[0].keys()), list(col_widths.values())) + table += _second_table_line(list(col_widths.values())) + for line in lines: + table += _regular_table_line([_maybe_round(v) for v in line.values()], list(col_widths.values())) + return table + + +_TRAINING_ARGS_KEYS = [ + "learning_rate", + "train_batch_size", + "eval_batch_size", + "seed", +] + + +def extract_hyperparameters_from_trainer(trainer): + hyperparameters = {k: getattr(trainer.args, k) for k in _TRAINING_ARGS_KEYS} + + if trainer.args.parallel_mode not in [ParallelMode.NOT_PARALLEL, ParallelMode.NOT_DISTRIBUTED]: + hyperparameters["distributed_type"] = ( + "multi-GPU" if trainer.args.parallel_mode == ParallelMode.DISTRIBUTED else trainer.args.parallel_mode.value + ) + if trainer.args.world_size > 1: + hyperparameters["num_devices"] = trainer.args.world_size + if trainer.args.gradient_accumulation_steps > 1: + hyperparameters["gradient_accumulation_steps"] = trainer.args.gradient_accumulation_steps + + total_train_batch_size = ( + trainer.args.train_batch_size * trainer.args.world_size * trainer.args.gradient_accumulation_steps + ) + if total_train_batch_size != hyperparameters["train_batch_size"]: + hyperparameters["total_train_batch_size"] = total_train_batch_size + total_eval_batch_size = trainer.args.eval_batch_size * trainer.args.world_size + if total_eval_batch_size != hyperparameters["eval_batch_size"]: + hyperparameters["total_eval_batch_size"] = total_eval_batch_size + + if trainer.args.optim: + optimizer_name = trainer.args.optim + optimizer_args = trainer.args.optim_args if trainer.args.optim_args else "No additional optimizer arguments" + + if "adam" in optimizer_name.lower(): + hyperparameters["optimizer"] = ( + f"Use {optimizer_name} with betas=({trainer.args.adam_beta1},{trainer.args.adam_beta2}) and" + f" epsilon={trainer.args.adam_epsilon} and optimizer_args={optimizer_args}" + ) + else: + hyperparameters["optimizer"] = f"Use {optimizer_name} and the args are:\n{optimizer_args}" + + hyperparameters["lr_scheduler_type"] = trainer.args.lr_scheduler_type.value + if trainer.args.warmup_ratio != 0.0: + hyperparameters["lr_scheduler_warmup_ratio"] = trainer.args.warmup_ratio + if trainer.args.warmup_steps != 0.0: + hyperparameters["lr_scheduler_warmup_steps"] = trainer.args.warmup_steps + if trainer.args.max_steps != -1: + hyperparameters["training_steps"] = trainer.args.max_steps + else: + hyperparameters["num_epochs"] = trainer.args.num_train_epochs + + if trainer.args.fp16: + if trainer.use_apex: + hyperparameters["mixed_precision_training"] = f"Apex, opt level {trainer.args.fp16_opt_level}" + else: + hyperparameters["mixed_precision_training"] = "Native AMP" + + if trainer.args.label_smoothing_factor != 0.0: + hyperparameters["label_smoothing_factor"] = trainer.args.label_smoothing_factor + + return hyperparameters diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_attn_mask_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_attn_mask_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0be1b3ed25531a6bd67aa143b0883815b358678d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_attn_mask_utils.py @@ -0,0 +1,487 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +IMPORTANT NOTICE: Every class and function in this file is deprecated in favor of using the much more general +`masking_utils.py` primitives. New code should not rely on it, it is only kept for backward compatibility for now, +and will be removed in the future. +""" + +from dataclasses import dataclass +from typing import Optional, Union + +import torch + +from .utils.import_utils import is_torchdynamo_compiling + + +@dataclass +class AttentionMaskConverter: + """ + A utility attention mask class that allows one to: + - Create a causal 4d mask + - Create a causal 4d mask with slided window + - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length, + key_value_length) that can be multiplied with attention scores + + Examples: + + ```python + >>> import torch + >>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter + + >>> converter = AttentionMaskConverter(True) + >>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, key_value_length=5, dtype=torch.float32) + tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38], + [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38], + [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38], + [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, -3.4028e+38], + [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, 0.0000e+00]]]]) + ``` + + Parameters: + is_causal (`bool`): + Whether the attention mask should be a uni-directional (causal) or bi-directional mask. + + sliding_window (`int`, *optional*): + Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer. + """ + + is_causal: bool + sliding_window: int + + def __init__(self, is_causal: bool, sliding_window: Optional[int] = None): + self.is_causal = is_causal + self.sliding_window = sliding_window + + if self.sliding_window is not None and self.sliding_window <= 0: + raise ValueError( + f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`" + ) + + def to_causal_4d( + self, + batch_size: int, + query_length: int, + key_value_length: int, + dtype: torch.dtype, + device: Union[torch.device, "str"] = "cpu", + ) -> Optional[torch.Tensor]: + """ + Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative + bias to upper right hand triangular matrix (causal mask). + """ + if not self.is_causal: + raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.") + + # If shape is not cached, create a new causal mask and cache it + input_shape = (batch_size, query_length) + past_key_values_length = key_value_length - query_length + + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + causal_4d_mask = None + if input_shape[-1] > 1 or self.sliding_window is not None: + causal_4d_mask = self._make_causal_mask( + input_shape, + dtype, + device=device, + past_key_values_length=past_key_values_length, + sliding_window=self.sliding_window, + ) + + return causal_4d_mask + + def to_4d( + self, + attention_mask_2d: torch.Tensor, + query_length: int, + dtype: torch.dtype, + key_value_length: Optional[int] = None, + ) -> torch.Tensor: + """ + Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length, + key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is + causal, a causal mask will be added. + """ + input_shape = (attention_mask_2d.shape[0], query_length) + + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + causal_4d_mask = None + if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal: + if key_value_length is None: + raise ValueError( + "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask." + ) + + past_key_values_length = key_value_length - query_length + causal_4d_mask = self._make_causal_mask( + input_shape, + dtype, + device=attention_mask_2d.device, + past_key_values_length=past_key_values_length, + sliding_window=self.sliding_window, + ) + elif self.sliding_window is not None: + raise NotImplementedError("Sliding window is currently only implemented for causal masking") + + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to( + attention_mask_2d.device + ) + + if causal_4d_mask is not None: + expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min) + + # expanded_attn_mask + causal_4d_mask can cause some overflow + expanded_4d_mask = expanded_attn_mask + + return expanded_4d_mask + + @staticmethod + def _make_causal_mask( + input_ids_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, + sliding_window: Optional[int] = None, + ): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + + # add lower triangular sliding window mask if necessary + if sliding_window is not None: + diagonal = past_key_values_length - sliding_window - 1 + + context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal) + # Recent changes in PyTorch prevent mutations on tensors converted with aten::_to_copy + # See https://github.com/pytorch/pytorch/issues/127571 + if is_torchdynamo_compiling(): + mask = mask.clone() + mask.masked_fill_(context_mask, torch.finfo(dtype).min) + + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + @staticmethod + def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = torch.tensor(1.0, dtype=dtype) - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + @staticmethod + def _unmask_unattended( + expanded_mask: torch.FloatTensor, + min_dtype: float, + ): + # fmt: off + """ + Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when + using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + Details: https://github.com/pytorch/pytorch/issues/110213 + + `expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len]. + `attention_mask` is [bsz, src_seq_len]. + + The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case of alibi attention bias. + + For example, if `expanded_mask` is (e.g. here left-padding case) + ``` + [[[[0, 0, 0], + [0, 0, 0], + [0, 0, 1]]], + [[[1, 0, 0], + [1, 1, 0], + [1, 1, 1]]], + [[[0, 0, 0], + [0, 1, 0], + [0, 1, 1]]]] + ``` + then the modified `expanded_mask` will be + ``` + [[[[1, 1, 1], <-- modified + [1, 1, 1], <-- modified + [0, 0, 1]]], + [[[1, 0, 0], + [1, 1, 0], + [1, 1, 1]]], + [[[1, 1, 1], <-- modified + [0, 1, 0], + [0, 1, 1]]]] + ``` + """ + # fmt: on + if expanded_mask.dtype == torch.bool: + raise ValueError( + "AttentionMaskConverter._unmask_unattended expects a float `expanded_mask`, got a BoolTensor." + ) + + return expanded_mask.mul(~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True)) + + @staticmethod + def _ignore_causal_mask_sdpa( + attention_mask: Optional[torch.Tensor], + inputs_embeds: torch.Tensor, + past_key_values_length: int, + sliding_window: Optional[int] = None, + is_training: bool = False, + ) -> bool: + """ + Detects whether the optional user-specified attention_mask & the automatically created causal mask can be + ignored in case PyTorch's SDPA is used, rather relying on SDPA's `is_causal` argument. + + In case no token is masked in the `attention_mask` argument, if `query_length == 1` or + `key_value_length == query_length`, we rather rely on SDPA `is_causal` argument to use causal/non-causal masks, + allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is + passed). + """ + + _, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1] + key_value_length = query_length + past_key_values_length + + is_tracing = torch.jit.is_tracing() or isinstance(inputs_embeds, torch.fx.Proxy) or is_torchdynamo_compiling() + + ignore_causal_mask = False + + if attention_mask is None: + # TODO: When tracing with TorchDynamo with fullgraph=True, the model is recompiled depending on the input + # shape, thus SDPA's `is_causal` argument is rightfully updated + # (see https://gist.github.com/fxmarty/1313f39037fc1c112508989628c57363). However, when using + # `torch.export` or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is + # hard-coded. If a user exports a model with q_len > 1, the exported model will hard-code `is_causal=True` + # which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108). + # Thus, we only set `ignore_causal_mask = True` if the model is set to training. + # + # Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal` + # ("TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor"). + if ( + (is_training or not is_tracing) + and (query_length == 1 or key_value_length == query_length) + and (sliding_window is None or key_value_length < sliding_window) + ): + ignore_causal_mask = True + elif sliding_window is None or key_value_length < sliding_window: + if len(attention_mask.shape) == 4: + return False + elif not is_tracing and torch.all(attention_mask == 1): + if query_length == 1 or key_value_length == query_length: + # For query_length == 1, causal attention and bi-directional attention are the same. + ignore_causal_mask = True + + # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore + # the attention mask, as SDPA causal mask generation may be wrong. We will set `is_causal=False` in + # SDPA and rely on Transformers attention_mask instead, hence not setting it to None here. + # Reference: https://github.com/pytorch/pytorch/issues/108108 + # TODO: maybe revisit this with https://github.com/pytorch/pytorch/pull/114823 in PyTorch 2.3. + + return ignore_causal_mask + + +def _prepare_4d_causal_attention_mask( + attention_mask: Optional[torch.Tensor], + input_shape: Union[torch.Size, tuple, list], + inputs_embeds: torch.Tensor, + past_key_values_length: int, + sliding_window: Optional[int] = None, +): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)` + + Args: + attention_mask (`torch.Tensor` or `None`): + A 2D attention mask of shape `(batch_size, key_value_length)` + input_shape (`tuple(int)` or `list(int)` or `torch.Size`): + The input shape should be a tuple that defines `(batch_size, query_length)`. + inputs_embeds (`torch.Tensor`): + The embedded inputs as a torch Tensor. + past_key_values_length (`int`): + The length of the key value cache. + sliding_window (`int`, *optional*): + If the model uses windowed attention, a sliding window should be passed. + """ + attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window) + + key_value_length = input_shape[-1] + past_key_values_length + + # 4d mask is passed through the layers + if attention_mask is not None and len(attention_mask.shape) == 2: + attention_mask = attn_mask_converter.to_4d( + attention_mask, input_shape[-1], key_value_length=key_value_length, dtype=inputs_embeds.dtype + ) + elif attention_mask is not None and len(attention_mask.shape) == 4: + expected_shape = (input_shape[0], 1, input_shape[1], key_value_length) + if tuple(attention_mask.shape) != expected_shape: + raise ValueError( + f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}." + ) + else: + # if the 4D mask has correct shape - invert it and fill with negative infinity + inverted_mask = 1.0 - attention_mask + attention_mask = inverted_mask.masked_fill( + inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min + ) + else: + attention_mask = attn_mask_converter.to_causal_4d( + input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device + ) + + return attention_mask + + +# Adapted from _prepare_4d_causal_attention_mask +def _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask: Optional[torch.Tensor], + input_shape: Union[torch.Size, tuple, list], + inputs_embeds: torch.Tensor, + past_key_values_length: int, + sliding_window: Optional[int] = None, +): + """ + Prepares the correct `attn_mask` argument to be used by `torch.nn.functional.scaled_dot_product_attention`. + + In case no token is masked in the `attention_mask` argument, we simply set it to `None` for the cases `query_length == 1` and + `key_value_length == query_length`, and rely instead on SDPA `is_causal` argument to use causal/non-causal masks, + allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed). + """ + attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window) + + key_value_length = input_shape[-1] + past_key_values_length + + # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1` + # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing. + # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400). + is_tracing = torch.jit.is_tracing() or isinstance(inputs_embeds, torch.fx.Proxy) or is_torchdynamo_compiling() + + ignore_causal_mask = AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + sliding_window=sliding_window, + ) + + if ignore_causal_mask: + expanded_4d_mask = None + elif attention_mask is None: + expanded_4d_mask = attn_mask_converter.to_causal_4d( + input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device + ) + else: + if attention_mask.dim() == 4: + expanded_4d_mask = attention_mask + else: + expanded_4d_mask = attn_mask_converter.to_4d( + attention_mask, + input_shape[-1], + dtype=inputs_embeds.dtype, + key_value_length=key_value_length, + ) + + # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + if not is_tracing and expanded_4d_mask.device.type == "cuda": + expanded_4d_mask = AttentionMaskConverter._unmask_unattended( + expanded_4d_mask, min_dtype=torch.finfo(inputs_embeds.dtype).min + ) + + return expanded_4d_mask + + +def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)` + + Args: + mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` + dtype (`torch.dtype`): + The torch dtype the created mask shall have. + tgt_len (`int`): + The target length or query length the created mask shall have. + """ + return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len) + + +def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)` + + Args: + mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` + dtype (`torch.dtype`): + The torch dtype the created mask shall have. + tgt_len (`int`): + The target length or query length the created mask shall have. + """ + _, key_value_length = mask.shape + tgt_len = tgt_len if tgt_len is not None else key_value_length + + is_tracing = torch.jit.is_tracing() or isinstance(mask, torch.fx.Proxy) or is_torchdynamo_compiling() + + # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture data-dependent controlflows. + if not is_tracing and torch.all(mask == 1): + return None + else: + return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len) + + +def _create_4d_causal_attention_mask( + input_shape: Union[torch.Size, tuple, list], + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, + sliding_window: Optional[int] = None, +) -> Optional[torch.Tensor]: + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` + + Args: + input_shape (`tuple(int)` or `list(int)` or `torch.Size`): + The input shape should be a tuple that defines `(batch_size, query_length)`. + dtype (`torch.dtype`): + The torch dtype the created mask shall have. + device (`int`): + The torch device the created mask shall have. + sliding_window (`int`, *optional*): + If the model uses windowed attention, a sliding window should be passed. + """ + attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window) + + key_value_length = past_key_values_length + input_shape[-1] + attention_mask = attn_mask_converter.to_causal_4d( + input_shape[0], input_shape[-1], key_value_length, dtype=dtype, device=device + ) + + return attention_mask diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_flash_attention_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_flash_attention_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5312b0dd9cd0b40f6c7b64356de203f579e8b263 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_flash_attention_utils.py @@ -0,0 +1,668 @@ +# Copyright 2025 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +import os +from functools import partial +from typing import Optional, TypedDict + +import torch +import torch.nn.functional as F + +from .utils import ( + is_flash_attn_2_available, + is_flash_attn_3_available, + is_flash_attn_greater_or_equal_2_10, + is_torch_npu_available, + logging, +) + + +logger = logging.get_logger(__name__) + + +# TODO Deprecate when all models have the attention interface +def flash_attn_supports_top_left_mask(): + if is_flash_attn_3_available(): + return False + if is_flash_attn_2_available(): + return not is_flash_attn_greater_or_equal_2_10() + + from .integrations.npu_flash_attention import is_npu_fa2_top_left_aligned_causal_mask + + return is_npu_fa2_top_left_aligned_causal_mask() + + +# TODO Deprecate when all models have the attention interface +def is_flash_attn_available(): + return is_flash_attn_3_available() or is_flash_attn_2_available() or is_torch_npu_available() + + +# `globals()` is not compatible with dynamo, hence we have do define them in global scope ourselves +_flash_fn = None +_flash_varlen_fn = None +_pad_fn = None +_unpad_fn = None + +# function that processes kwargs, generalized to handle any supported kwarg within the function +_process_flash_kwargs_fn = None +# exceptions where hf API doesn't match the original flash attention API +_hf_api_to_flash_mapping = { + "dropout": "dropout_p", + "sliding_window": "window_size", +} + + +def _lazy_imports(implementation: Optional[str]): + """ + Lazy loads the respective flash attention implementations. + + Return: + flash_attn_func: The base flash attention function. + flash_attn_varlen_func: The flash attention function supporting variable sequence lengths, + e.g. for padding-free training. + pad_input: The function to pad inputs into one sequence and returning the respective kwargs. + unpad_input: The function to unpad outputs based on the kwargs (from pad_input). + """ + is_fa2 = is_flash_attn_2_available() + is_fa3 = is_flash_attn_3_available() + + pad_input, unpad_input = _pad_input, _unpad_input + + if (implementation == "flash_attention_2" and is_fa2) or (implementation is None and is_fa2 and not is_fa3): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import pad_input, unpad_input + elif is_torch_npu_available(): + # Package `flash-attn` is unavailable on Ascend NPU, which will cause ImportError + # Flash-Attention2 related apis for Ascend NPU must be imported from `.integrations.npu_flash_attention` module + from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_func + from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func + else: + if implementation == "flash_attention_3" or (implementation is None and is_fa3): + from flash_attn_interface import flash_attn_func, flash_attn_varlen_func + # Kernels fallback + else: + flash_attn_func = getattr(implementation, "flash_attn_func", None) + flash_attn_varlen_func = getattr(implementation, "flash_attn_varlen_func", None) + if flash_attn_varlen_func is None or flash_attn_func is None: + raise ValueError( + f"Could not find the currently requested flash attention implementation at `{implementation}`." + f"Make sure that you request a valid kernel from the hub, e.g. `kernels-community/flash-attn`." + ) + + return flash_attn_func, flash_attn_varlen_func, pad_input, unpad_input + + +def _lazy_define_process_function(flash_function): + """ + Depending on the version and kernel some features are not supported. Due to limitations in + `torch.compile`, we opt to statically type which (optional) kwarg parameters are supported + within `_process_flash_attention_kwargs`. + + NOTE: While all supported kwargs are marked as `True`, everything else is marked as `False`. + This might be confusing for kwargs that we use in any case, e.g. `is_causal`. + """ + + flash_parameters = inspect.signature(flash_function).parameters + process_parameters = inspect.signature(_process_flash_attention_kwargs).parameters + + supports_mapping = {} + for param in process_parameters: + fa_param = _hf_api_to_flash_mapping.get(param, param) + supports_mapping[fa_param] = fa_param in flash_parameters + + return partial(_process_flash_attention_kwargs, supports_mapping=supports_mapping) + + +def lazy_import_flash_attention(implementation: Optional[str], force_import: Optional[bool] = False): + """ + Lazily import flash attention and return the respective functions + flags. + + NOTE: For fullgraph, this needs to be called before compile, while no fullgraph can + work without preloading. See `load_and_register_kernel` in `integrations.hub_kernels`. + """ + global _flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn + if force_import or any(k is None for k in [_flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn]): + _flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn = _lazy_imports(implementation) + + global _process_flash_kwargs_fn + if force_import or _process_flash_kwargs_fn is None: + _process_flash_kwargs_fn = _lazy_define_process_function(_flash_varlen_fn) + + return (_flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn), _process_flash_kwargs_fn + + +def _index_first_axis(tensor, indices): + """ + A local implementation of the PyTorch indexing operation `tensor[indices]` on the first axis, + after flattening the first two dimensions of the tensor. This is functionally equivalent to + FA2's `index_first_axis` and replaces the need to import it. + """ + # The input tensor is expected to be of shape (batch, seq_len, ...). We flatten the first + # two dimensions to get (total_tokens, ...) before indexing. + reshaped_tensor = tensor.reshape(-1, *tensor.shape[2:]) + return reshaped_tensor[indices] + + +def _unpad_input(hidden_states, attention_mask, unused_mask=None): + """ + unpad_input function for flash attention variants that do not have them within their pkg themselves, e.g. fa3. + + Arguments: + hidden_states: (batch, seqlen, ...) + attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. + unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused. + + Return: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask. + indices: (total_nnz), the indices of masked tokens from the flattened input sequence. + cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. + max_seqlen_in_batch: int + seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask. + """ + all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask + seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32) + used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + + return ( + _index_first_axis(hidden_states, indices), + indices, + cu_seqlens, + max_seqlen_in_batch, + used_seqlens_in_batch, + ) + + +def _pad_input(hidden_states, indices, batch, seqlen): + """ + pad_input function for flash attention variants that do not have them within their pkg themselves, e.g. fa3. + + Arguments: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. + indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence. + batch: int, batch size for the padded sequence. + seqlen: int, maximum sequence length for the padded sequence. + + Return: + hidden_states: (batch, seqlen, ...) + """ + dim = hidden_states.shape[1:] + output = torch.zeros((batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype) + output[indices] = hidden_states + return output.view(batch, seqlen, *dim) + + +def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]: + """ + Retrieves indexing data required to repad unpadded (ragged) tensors. + + Arguments: + attention_mask (`torch.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + + Return: + indices (`torch.Tensor`): + The indices of non-masked tokens from the flattened input sequence. + cu_seqlens (`torch.Tensor`): + The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + max_seqlen_in_batch (`int`): + Maximum sequence length in batch. + """ + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + # NOTE: Similar to the `.item()` in prepare_fa2_from_position_ids, with torch compile, + # this might cause a graph break + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def _upad_input( + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + attention_mask: torch.Tensor, + query_length: int, + unpad_input_func, +): + """ + Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches. + This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary + tensors for query, key, value tensors. + + Arguments: + query_layer (`torch.Tensor`): + Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). + key_layer (`torch.Tensor`): + Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + value_layer (`torch.Tensor`): + Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + attention_mask (`torch.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + query_length (`int`): + Target length. + unpad_input_func: + The function to use for unpadding the input tensors. + + Return: + query_layer (`torch.Tensor`): + Query state without padding. Shape: (total_target_length, num_heads, head_dim). + key_layer (`torch.Tensor`): + Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + value_layer (`torch.Tensor`): + Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + indices_q (`torch.Tensor`): + The indices of non-masked tokens from the flattened input target sequence. + (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`): + The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`): + Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). + """ + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + + # With static caches, the k/v states may be larger than the mask -> we need to slice them to avoid generating garbage + # It's a bit of an anti-pattern, but otherwise we silently compute wrong attentions scores + if key_layer.shape[1] > (seq_len := attention_mask.shape[-1]): + key_layer, value_layer = key_layer[:, :seq_len, :, :], value_layer[:, :seq_len, :, :] + + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = _index_first_axis(key_layer, indices_k) + value_layer = _index_first_axis(value_layer, indices_k) + if query_length == kv_seq_len: + query_layer = _index_first_axis(query_layer, indices_k) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input_func(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +def prepare_fa_kwargs_from_position_ids(position_ids): + """ + This function returns all the necessary kwargs to call `flash_attn_varlen_func` extracted from position_ids. + + Arguments: + position_ids (`torch.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + + Return: + (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`): + The cumulative sequence lengths for the target (query) and source (key, value), used to index into + ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`): + Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, + `max_seqlen_in_batch_k` for the source sequence i.e. key/value). + """ + tensor_kwargs = {"dtype": torch.int32, "device": position_ids.device} + + position_ids = position_ids.view(-1) + indices_q = (position_ids == 0).nonzero().view(-1) + + cu_seq_lens_q = torch.cat( + ( + indices_q.to(**tensor_kwargs), + torch.tensor(position_ids.size(), **tensor_kwargs), + ) + ) + cu_seq_lens_k = cu_seq_lens_q + + # https://github.com/Dao-AILab/flash-attention/blob/2dd8078adc1d9b74e315ee99718c0dea0de8eeb6/flash_attn/flash_attn_interface.py#L1423-L1424 + # We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing + # for some models (e.g. qwen2-vl). + max_length_q = cu_seq_lens_q.diff().max() + # NOTE: With torch compile, this will cause a graph break if you don't set + # `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call + # `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass. + # This is a limitation of flash attention API, as the function `flash_attn_varlen_func` + # requires `max_length_q`, `max_length_k` to be passed as `int` and not `torch.Tensor`. + max_length_q = max_length_q.item() + max_length_k = max_length_q + + return (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) + + +def _prepare_from_posids(query, key, value, position_ids): + """ + This function returns necessary arguments to call `flash_attn_varlen_func`. + All three query, key, value states will be flattened. + Cumulative lengths of each examples in the batch will be extracted from position_ids. + NOTE: ideally cumulative lengths should be prepared at the data collator stage + + Arguments: + query (`torch.Tensor`): + Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). + key (`torch.Tensor`): + Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + value (`torch.Tensor`): + Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + position_ids (`torch.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + + Return: + query (`torch.Tensor`): + Query state without padding. Shape: (total_target_length, num_heads, head_dim). + key (`torch.Tensor`): + Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + value (`torch.Tensor`): + Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`): + The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`): + Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). + """ + query = query.contiguous().view(-1, query.size(-2), query.size(-1)) + key = key.contiguous().view(-1, key.size(-2), key.size(-1)) + value = value.contiguous().view(-1, value.size(-2), value.size(-1)) + + (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = prepare_fa_kwargs_from_position_ids(position_ids) + + return (query, key, value, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k)) + + +def _is_packed_sequence(position_ids, batch_size): + """ + Check the position ids whether packed sequences are indicated or not + 1. Position ids exist + 2. Flattened sequences only are supported + 3. Compile-friendly `not (torch.diff(position_ids, dim=-1) >= 0).all()`, i.e. we have multiple increasing sequences + """ + if position_ids is None: + return False + + increasing_position_sequences = ( + torch.arange(position_ids.shape[1], device=position_ids.device) + position_ids.min() + ) + return batch_size == 1 and (increasing_position_sequences - position_ids).abs().sum().bool() + + +def fa_peft_integration_check( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + target_dtype: Optional[torch.dtype] = None, +): + """ + PEFT usually casts the layer norms in float32 for training stability reasons + therefore the input hidden states gets silently casted in float32. Hence, we need + cast them back in float16 / bfloat16 just to be sure everything works as expected. + This might slowdown training & inference so it is recommended to not cast the LayerNorms! + """ + if target_dtype and q.dtype == torch.float32: + logger.warning_once(f"Casting fp32 inputs back to {target_dtype} for flash-attn compatibility.") + q, k, v = q.to(target_dtype), k.to(target_dtype), v.to(target_dtype) + return q, k, v + + +class FlashAttentionKwargs(TypedDict, total=False): + """ + Keyword arguments for Flash Attention with Compile. + + Attributes: + cu_seq_lens_q (`torch.LongTensor`, *optional*) + Gets cumulative sequence length for query state. + cu_seq_lens_k (`torch.LongTensor`, *optional*) + Gets cumulative sequence length for key state. + max_length_q (`int`, *optional*): + Maximum sequence length for query state. + max_length_k (`int`, *optional*): + Maximum sequence length for key state. + """ + + cu_seq_lens_q: Optional[torch.LongTensor] + cu_seq_lens_k: Optional[torch.LongTensor] + max_length_q: Optional[int] + max_length_k: Optional[int] + + +def _process_flash_attention_kwargs( + query_length: int, + key_length: int, + is_causal: bool, + dropout: float = 0.0, + softmax_scale: Optional[float] = None, + sliding_window: Optional[int] = None, + use_top_left_mask: bool = False, + softcap: Optional[float] = None, + deterministic: Optional[bool] = None, + s_aux: Optional[torch.Tensor] = None, + supports_mapping: Optional[dict[str, bool]] = None, + **kwargs, +): + """ + Returns a set of kwargs that are passed down to the according flash attention function based on + requested features and whether it is supported - depends on the version and kernel implementation + which is dynamically configured at `lazy_import_flash_attention`. The (un)supported features can be + inspected in `supports_mapping`, see `_lazy_define_process_function` for more details. + + Args: + query_length (`int`): + Length of the query states + key_length (`int`): + Length of the key states + is_causal (`bool`): + Whether we perform causal (decoder) attention or full attention. + dropout (`float`): + Attention dropout. + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to `1 / sqrt(head_dim)`. + sliding_window (`int`, *optional*): + The size of the sliding window, i.e. we look at a max of `sliding_window` tokens back. + use_top_left_mask (`bool`): + Deprecated behavior of older versions of flash attention requiring different masking. + softcap (`float`, *optional*): + Softcap for the attention logits, used e.g. in gemma2. + deterministic (`bool`, *optional*): + Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled. + s_aux (`torch.Tensor`, *optional*): + Attention sink auxiliary that adds a `bias` to the attention calculation via an additional head. + Return: + flash_kwargs (`dict`): + A dict of kwargs that are requested and supported. + """ + flash_kwargs = { + "causal": is_causal and not (use_top_left_mask and query_length == 1), + "softmax_scale": softmax_scale, + } + + if supports_mapping["dropout_p"]: + flash_kwargs["dropout_p"] = dropout + + if supports_mapping["window_size"] and sliding_window is not None and key_length > sliding_window: + # The flash attention API sets inclusive boundaries, i.e. (4, 0) would take 4 tokens to the left + # and the current token for a total size of 5. However, we usually define our window sizes by + # their total window size (when causal). Encoder models as of now seldom use SWA and when they + # do, they have a custom workaround (e.g. ModernBERT) which would align with this symmetric logic, i.e. + # for a total of `2*sliding_window + 1`. + flash_kwargs["window_size"] = (sliding_window - 1, sliding_window - 1) + + if supports_mapping["deterministic"]: + flash_kwargs["deterministic"] = ( + deterministic if deterministic is not None else os.getenv("FLASH_ATTENTION_DETERMINISTIC", "0") == "1" + ) + + if supports_mapping["softcap"] and softcap is not None: + flash_kwargs["softcap"] = softcap + + # Only within kernel implementation atm + if supports_mapping["s_aux"] and s_aux is not None: + flash_kwargs["s_aux"] = s_aux + + return flash_kwargs + + +def _flash_attention_forward( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: Optional[torch.Tensor], + query_length: int, + is_causal: bool, + dropout: float = 0.0, + position_ids: Optional[torch.Tensor] = None, + softmax_scale: Optional[float] = None, + sliding_window: Optional[int] = None, + use_top_left_mask: bool = False, + softcap: Optional[float] = None, + deterministic: Optional[bool] = None, + cu_seq_lens_q: Optional[torch.LongTensor] = None, + cu_seq_lens_k: Optional[torch.LongTensor] = None, + max_length_q: Optional[int] = None, + max_length_k: Optional[int] = None, + target_dtype: Optional[torch.dtype] = None, + implementation: Optional[str] = None, + **kwargs, +): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + (Optional) kwargs are described further in `_process_flash_attention_kwargs` and `FlashAttentionKwargs`. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`, *optional*): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + implementation (`str`, *optional*): + The attention implementation to use. If None, will default to the one based on the environment. + """ + (flash_fn, flash_varlen_fn, pad_fn, unpad_fn), process_flash_kwargs_fn = lazy_import_flash_attention( + implementation + ) + + # PEFT possibly silently casts tensors to fp32, this potentially reconverts to correct dtype or is a no op + query_states, key_states, value_states = fa_peft_integration_check( + query_states, key_states, value_states, target_dtype + ) + + # Extract the flash attention kwargs that have been requested (and are supported by the implementation) + flash_kwargs = process_flash_kwargs_fn( + query_length=query_length, + key_length=key_states.size(1), + is_causal=is_causal, + dropout=dropout, + softmax_scale=softmax_scale, + sliding_window=sliding_window, + use_top_left_mask=use_top_left_mask, + softcap=softcap, + deterministic=deterministic, + **kwargs, + ) + + # We will use `flash_varlen_fn` to prevent cross-example attention and also allow padding free approach under two cases: + # Case 1. If position ids is provided and the position ids indicate packed sequences, see `_is_packed_sequence`. + # Case 2. Some models pass directly pre-computed `cu_seqlens` so we don't need to infer it from position ids. It is safe to + # use `flash_varlen_fn` knowing we already have all necessary the kwargs. + # + # NOTE: it is user's responsibility to take care of flattening `position_ids` if that's needed by the model. + # See #39121 for more information. + is_fa_with_position_ids = _is_packed_sequence(position_ids, batch_size=query_states.size(0)) + is_fa_with_varlen_kwargs = all( + kwarg is not None for kwarg in (cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k) + ) + + # Contains at least one padding token in the sequence + if attention_mask is not None: + q, k, v, indices_q, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = _upad_input( + query_states, key_states, value_states, attention_mask, query_length, unpad_fn + ) + + # TODO for now this is required to work with + # https://huggingface.co/kernels-community/metal-flash-sdpa/blob/main/torch-ext/metal_flash_sdpa/__init__.py + if "mps" in str(q.device): + cu_seq_lens_k = cu_seq_lens_k.clone() + + out_unpad = flash_varlen_fn( + q, + k, + v, + cu_seqlens_q=cu_seq_lens_q, + cu_seqlens_k=cu_seq_lens_k, + max_seqlen_q=max_length_q, + max_seqlen_k=max_length_k, + **flash_kwargs, + ) + if isinstance(out_unpad, tuple): + out_unpad = out_unpad[0] + + out = pad_fn(out_unpad, indices_q, query_states.size(0), query_length) + + # Padding free, i.e. sequences flattened into one total sequence + elif is_fa_with_varlen_kwargs or is_fa_with_position_ids: + if cu_seq_lens_q is None or cu_seq_lens_k is None: + q, k, v, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = _prepare_from_posids( + query_states, key_states, value_states, position_ids + ) + else: + q = query_states.reshape(-1, query_states.size(-2), query_states.size(-1)) + k = key_states.reshape(-1, key_states.size(-2), key_states.size(-1)) + v = value_states.reshape(-1, value_states.size(-2), value_states.size(-1)) + + # TODO for now this is required to work with + # https://huggingface.co/kernels-community/metal-flash-sdpa/blob/main/torch-ext/metal_flash_sdpa/__init__.py + if "mps" in str(q.device): + cu_seq_lens_k = cu_seq_lens_k.clone() + + out = flash_varlen_fn( + q, + k, + v, + cu_seqlens_q=cu_seq_lens_q, + cu_seqlens_k=cu_seq_lens_k, + max_seqlen_q=max_length_q, + max_seqlen_k=max_length_k, + **flash_kwargs, + ) + if isinstance(out, tuple): + out = out[0] + + out = out.view(query_states.size(0), -1, out.size(-2), out.size(-1)) + + # No padding + else: + out = flash_fn(query_states, key_states, value_states, **flash_kwargs) + if isinstance(out, tuple): + out = out[0] + + return out diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_flax_outputs.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_flax_outputs.py new file mode 100644 index 0000000000000000000000000000000000000000..5a25a6059a255659c6d900b35d2ffa7cab57f071 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_flax_outputs.py @@ -0,0 +1,700 @@ +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +import flax +import jax.numpy as jnp + +from .utils import ModelOutput + + +@flax.struct.dataclass +class FlaxBaseModelOutput(ModelOutput): + """ + Base class for model's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + last_hidden_state: Optional[jnp.ndarray] = None + hidden_states: Optional[tuple[jnp.ndarray]] = None + attentions: Optional[tuple[jnp.ndarray]] = None + + +@flax.struct.dataclass +class FlaxBaseModelOutputWithNoAttention(ModelOutput): + """ + Base class for model's outputs, with potential hidden states. + + Args: + last_hidden_state (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one + for the output of each layer) of shape `(batch_size, num_channels, height, width)`. Hidden-states of the + model at the output of each layer plus the optional initial embedding outputs. + """ + + last_hidden_state: Optional[jnp.ndarray] = None + hidden_states: Optional[tuple[jnp.ndarray]] = None + + +@flax.struct.dataclass +class FlaxBaseModelOutputWithPoolingAndNoAttention(ModelOutput): + """ + Base class for model's outputs that also contains a pooling of the last hidden states. + + Args: + last_hidden_state (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`jnp.ndarray` of shape `(batch_size, hidden_size)`): + Last layer hidden-state after a pooling operation on the spatial dimensions. + hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one + for the output of each layer) of shape `(batch_size, num_channels, height, width)`. Hidden-states of the + model at the output of each layer plus the optional initial embedding outputs. + """ + + last_hidden_state: Optional[jnp.ndarray] = None + pooler_output: Optional[jnp.ndarray] = None + hidden_states: Optional[tuple[jnp.ndarray]] = None + + +@flax.struct.dataclass +class FlaxImageClassifierOutputWithNoAttention(ModelOutput): + """ + Base class for outputs of image classification models. + + Args: + logits (`jnp.ndarray` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when + `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one + for the output of each stage) of shape `(batch_size, num_channels, height, width)`. Hidden-states (also + called feature maps) of the model at the output of each stage. + """ + + logits: Optional[jnp.ndarray] = None + hidden_states: Optional[tuple[jnp.ndarray]] = None + + +@flax.struct.dataclass +class FlaxBaseModelOutputWithPast(ModelOutput): + """ + Base class for model's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + past_key_values (`dict[str, jnp.ndarray]`): + Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast + auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. + hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + last_hidden_state: Optional[jnp.ndarray] = None + past_key_values: Optional[dict[str, jnp.ndarray]] = None + hidden_states: Optional[tuple[jnp.ndarray]] = None + attentions: Optional[tuple[jnp.ndarray]] = None + + +@flax.struct.dataclass +class FlaxBaseModelOutputWithPooling(ModelOutput): + """ + Base class for model's outputs that also contains a pooling of the last hidden states. + + Args: + last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`jnp.ndarray` of shape `(batch_size, hidden_size)`): + Last layer hidden-state of the first token of the sequence (classification token) further processed by a + Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence + prediction (classification) objective during pretraining. + hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + last_hidden_state: Optional[jnp.ndarray] = None + pooler_output: Optional[jnp.ndarray] = None + hidden_states: Optional[tuple[jnp.ndarray]] = None + attentions: Optional[tuple[jnp.ndarray]] = None + + +@flax.struct.dataclass +class FlaxBaseModelOutputWithPoolingAndCrossAttentions(ModelOutput): + """ + Base class for model's outputs that also contains a pooling of the last hidden states. + + Args: + last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`jnp.ndarray` of shape `(batch_size, hidden_size)`): + Last layer hidden-state of the first token of the sequence (classification token) after further processing + through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns + the classification token after processing through a linear layer and a tanh activation function. The linear + layer weights are trained from the next sentence prediction (classification) objective during pretraining. + hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one + for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` + input) to speed up sequential decoding. + """ + + last_hidden_state: Optional[jnp.ndarray] = None + pooler_output: Optional[jnp.ndarray] = None + hidden_states: Optional[tuple[jnp.ndarray]] = None + past_key_values: Optional[tuple[tuple[jnp.ndarray]]] = None + attentions: Optional[tuple[jnp.ndarray]] = None + cross_attentions: Optional[tuple[jnp.ndarray]] = None + + +@flax.struct.dataclass +class FlaxBaseModelOutputWithPastAndCrossAttentions(ModelOutput): + """ + Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding). + + Args: + last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` + input) to speed up sequential decoding. + hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + """ + + last_hidden_state: Optional[jnp.ndarray] = None + past_key_values: Optional[tuple[tuple[jnp.ndarray]]] = None + hidden_states: Optional[tuple[jnp.ndarray]] = None + attentions: Optional[tuple[jnp.ndarray]] = None + cross_attentions: Optional[tuple[jnp.ndarray]] = None + + +@flax.struct.dataclass +class FlaxSeq2SeqModelOutput(ModelOutput): + """ + Base class for model encoder's outputs that also contains : pre-computed hidden states that can speed up sequential + decoding. + + Args: + last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. + decoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. + encoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + last_hidden_state: Optional[jnp.ndarray] = None + past_key_values: Optional[tuple[tuple[jnp.ndarray]]] = None + decoder_hidden_states: Optional[tuple[jnp.ndarray]] = None + decoder_attentions: Optional[tuple[jnp.ndarray]] = None + cross_attentions: Optional[tuple[jnp.ndarray]] = None + encoder_last_hidden_state: Optional[jnp.ndarray] = None + encoder_hidden_states: Optional[tuple[jnp.ndarray]] = None + encoder_attentions: Optional[tuple[jnp.ndarray]] = None + + +@flax.struct.dataclass +class FlaxCausalLMOutputWithCrossAttentions(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs. + + Args: + logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Cross attentions weights after the attention softmax, used to compute the weighted average in the + cross-attention heads. + past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `jnp.ndarray` tuples of length `config.n_layers`, with each tuple containing the cached key, value + states of the self-attention and the cross-attention layers if model is used in encoder-decoder setting. + Only relevant if `config.is_decoder = True`. + + Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + """ + + logits: Optional[jnp.ndarray] = None + past_key_values: Optional[tuple[tuple[jnp.ndarray]]] = None + hidden_states: Optional[tuple[jnp.ndarray]] = None + attentions: Optional[tuple[jnp.ndarray]] = None + cross_attentions: Optional[tuple[jnp.ndarray]] = None + + +@flax.struct.dataclass +class FlaxMaskedLMOutput(ModelOutput): + """ + Base class for masked language models outputs. + + Args: + logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + logits: Optional[jnp.ndarray] = None + hidden_states: Optional[tuple[jnp.ndarray]] = None + attentions: Optional[tuple[jnp.ndarray]] = None + + +FlaxCausalLMOutput = FlaxMaskedLMOutput + + +@flax.struct.dataclass +class FlaxSeq2SeqLMOutput(ModelOutput): + """ + Base class for sequence-to-sequence language models outputs. + + Args: + logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. + decoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. + encoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + logits: Optional[jnp.ndarray] = None + past_key_values: Optional[tuple[tuple[jnp.ndarray]]] = None + decoder_hidden_states: Optional[tuple[jnp.ndarray]] = None + decoder_attentions: Optional[tuple[jnp.ndarray]] = None + cross_attentions: Optional[tuple[jnp.ndarray]] = None + encoder_last_hidden_state: Optional[jnp.ndarray] = None + encoder_hidden_states: Optional[tuple[jnp.ndarray]] = None + encoder_attentions: Optional[tuple[jnp.ndarray]] = None + + +@flax.struct.dataclass +class FlaxNextSentencePredictorOutput(ModelOutput): + """ + Base class for outputs of models predicting if two sentences are consecutive or not. + + Args: + logits (`jnp.ndarray` of shape `(batch_size, 2)`): + Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation + before SoftMax). + hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + logits: Optional[jnp.ndarray] = None + hidden_states: Optional[tuple[jnp.ndarray]] = None + attentions: Optional[tuple[jnp.ndarray]] = None + + +@flax.struct.dataclass +class FlaxSequenceClassifierOutput(ModelOutput): + """ + Base class for outputs of sentence classification models. + + Args: + logits (`jnp.ndarray` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + logits: Optional[jnp.ndarray] = None + hidden_states: Optional[tuple[jnp.ndarray]] = None + attentions: Optional[tuple[jnp.ndarray]] = None + + +@flax.struct.dataclass +class FlaxSeq2SeqSequenceClassifierOutput(ModelOutput): + """ + Base class for outputs of sequence-to-sequence sentence classification models. + + Args: + logits (`jnp.ndarray` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. + decoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. + encoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + logits: Optional[jnp.ndarray] = None + past_key_values: Optional[tuple[tuple[jnp.ndarray]]] = None + decoder_hidden_states: Optional[tuple[jnp.ndarray]] = None + decoder_attentions: Optional[tuple[jnp.ndarray]] = None + cross_attentions: Optional[tuple[jnp.ndarray]] = None + encoder_last_hidden_state: Optional[jnp.ndarray] = None + encoder_hidden_states: Optional[tuple[jnp.ndarray]] = None + encoder_attentions: Optional[tuple[jnp.ndarray]] = None + + +@flax.struct.dataclass +class FlaxMultipleChoiceModelOutput(ModelOutput): + """ + Base class for outputs of multiple choice models. + + Args: + logits (`jnp.ndarray` of shape `(batch_size, num_choices)`): + *num_choices* is the second dimension of the input tensors. (see *input_ids* above). + + Classification scores (before SoftMax). + hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + logits: Optional[jnp.ndarray] = None + hidden_states: Optional[tuple[jnp.ndarray]] = None + attentions: Optional[tuple[jnp.ndarray]] = None + + +@flax.struct.dataclass +class FlaxTokenClassifierOutput(ModelOutput): + """ + Base class for outputs of token classification models. + + Args: + logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.num_labels)`): + Classification scores (before SoftMax). + hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + logits: Optional[jnp.ndarray] = None + hidden_states: Optional[tuple[jnp.ndarray]] = None + attentions: Optional[tuple[jnp.ndarray]] = None + + +@flax.struct.dataclass +class FlaxQuestionAnsweringModelOutput(ModelOutput): + """ + Base class for outputs of question answering models. + + Args: + start_logits (`jnp.ndarray` of shape `(batch_size, sequence_length)`): + Span-start scores (before SoftMax). + end_logits (`jnp.ndarray` of shape `(batch_size, sequence_length)`): + Span-end scores (before SoftMax). + hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + start_logits: Optional[jnp.ndarray] = None + end_logits: Optional[jnp.ndarray] = None + hidden_states: Optional[tuple[jnp.ndarray]] = None + attentions: Optional[tuple[jnp.ndarray]] = None + + +@flax.struct.dataclass +class FlaxSeq2SeqQuestionAnsweringModelOutput(ModelOutput): + """ + Base class for outputs of sequence-to-sequence question answering models. + + Args: + start_logits (`jnp.ndarray` of shape `(batch_size, sequence_length)`): + Span-start scores (before SoftMax). + end_logits (`jnp.ndarray` of shape `(batch_size, sequence_length)`): + Span-end scores (before SoftMax). + past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. + decoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. + encoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + start_logits: Optional[jnp.ndarray] = None + end_logits: Optional[jnp.ndarray] = None + past_key_values: Optional[tuple[tuple[jnp.ndarray]]] = None + decoder_hidden_states: Optional[tuple[jnp.ndarray]] = None + decoder_attentions: Optional[tuple[jnp.ndarray]] = None + cross_attentions: Optional[tuple[jnp.ndarray]] = None + encoder_last_hidden_state: Optional[jnp.ndarray] = None + encoder_hidden_states: Optional[tuple[jnp.ndarray]] = None + encoder_attentions: Optional[tuple[jnp.ndarray]] = None diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_flax_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_flax_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bc9a4d473f36f95bb13d6de17dc0bfa7cfae2279 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_flax_utils.py @@ -0,0 +1,1274 @@ +# coding=utf-8 +# Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import gc +import json +import os +import warnings +from functools import partial +from pickle import UnpicklingError +from typing import Any, Optional, Union + +import flax.linen as nn +import jax +import jax.numpy as jnp +import msgpack.exceptions +from flax.core.frozen_dict import FrozenDict, unfreeze +from flax.serialization import from_bytes, to_bytes +from flax.traverse_util import flatten_dict, unflatten_dict +from jax.random import PRNGKey + +from .configuration_utils import PretrainedConfig +from .dynamic_module_utils import custom_object_save +from .generation import FlaxGenerationMixin, GenerationConfig +from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict +from .utils import ( + FLAX_WEIGHTS_INDEX_NAME, + FLAX_WEIGHTS_NAME, + SAFE_WEIGHTS_INDEX_NAME, + SAFE_WEIGHTS_NAME, + WEIGHTS_INDEX_NAME, + WEIGHTS_NAME, + PushToHubMixin, + add_code_sample_docstrings, + add_start_docstrings_to_model_forward, + cached_file, + copy_func, + download_url, + has_file, + is_offline_mode, + is_remote_url, + logging, + replace_return_docstrings, +) +from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files +from .utils.import_utils import is_safetensors_available + + +if is_safetensors_available(): + from safetensors import safe_open + from safetensors.flax import load_file as safe_load_file + from safetensors.flax import save_file as safe_save_file + +logger = logging.get_logger(__name__) + + +def quick_gelu(x): + return x * jax.nn.sigmoid(1.702 * x) + + +ACT2FN = { + "gelu": partial(nn.gelu, approximate=False), + "relu": nn.relu, + "silu": nn.swish, + "swish": nn.swish, + "gelu_new": partial(nn.gelu, approximate=True), + "quick_gelu": quick_gelu, + "gelu_pytorch_tanh": partial(nn.gelu, approximate=True), + "tanh": nn.tanh, +} + + +def flax_shard_checkpoint(params, max_shard_size="10GB"): + """ + Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a + given size. The sub-checkpoints are determined by iterating through the `state_dict` in the order of its keys, so + there is no optimization made to make each sub-checkpoint as close as possible to the maximum size passed. For + example, if the limit is 10GB and we have weights of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as + [6GB], [6+2GB], [6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB]. + + + + If one of the model's weight is bigger that `max_shard_size`, it will end up in its own sub-checkpoint which will + have a size greater than `max_shard_size`. + + + + Args: + params (`Union[Dict, FrozenDict]`): A `PyTree` of model parameters. + max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`): + The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit + (like `"5MB"`). + """ + max_shard_size = convert_file_size_to_int(max_shard_size) + + sharded_state_dicts = [] + current_block = {} + current_block_size = 0 + total_size = 0 + + # flatten the weights to chunk + weights = flatten_dict(params, sep="/") + for item in weights: + weight_size = weights[item].size * weights[item].dtype.itemsize + + # If this weight is going to tip up over the maximal size, we split. + if current_block_size + weight_size > max_shard_size: + sharded_state_dicts.append(current_block) + current_block = {} + current_block_size = 0 + + current_block[item] = weights[item] + current_block_size += weight_size + total_size += weight_size + + # Add the last block + sharded_state_dicts.append(current_block) + + # If we only have one shard, we return it + if len(sharded_state_dicts) == 1: + return {FLAX_WEIGHTS_NAME: sharded_state_dicts[0]}, None + + # Otherwise, let's build the index + weight_map = {} + shards = {} + for idx, shard in enumerate(sharded_state_dicts): + shard_file = FLAX_WEIGHTS_NAME.replace(".msgpack", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.msgpack") + shards[shard_file] = shard + for weight_name in shard: + weight_map[weight_name] = shard_file + + # Add the metadata + metadata = {"total_size": total_size} + index = {"metadata": metadata, "weight_map": weight_map} + return shards, index + + +class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): + r""" + Base class for all models. + + [`FlaxPreTrainedModel`] takes care of storing the configuration of the models and handles methods for loading, + downloading and saving models. + + Class attributes (overridden by derived classes): + + - **config_class** ([`PretrainedConfig`]) -- A subclass of [`PretrainedConfig`] to use as configuration class + for this model architecture. + - **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived + classes of the same architecture adding modules on top of the base model. + - **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP + models, `pixel_values` for vision models and `input_values` for speech models). + """ + + config_class = None + base_model_prefix = "" + main_input_name = "input_ids" + _auto_class = None + _missing_keys = set() + + def __init__( + self, + config: PretrainedConfig, + module: nn.Module, + input_shape: tuple = (1, 1), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + ): + logger.warning_once( + "TensorFlow and JAX classes are deprecated and will be removed in Transformers v5. We " + "recommend migrating to PyTorch classes or pinning your version of Transformers." + ) + if config is None: + raise ValueError("config cannot be None") + + if module is None: + raise ValueError("module cannot be None") + + # Those are private to be exposed as typed property on derived classes. + self._config = config + self._module = module + + # Those are public as their type is generic to every derived classes. + self.key = PRNGKey(seed) + self.dtype = dtype + self.input_shape = input_shape + self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None + + # To check if the model was initialized automatically. + self._is_initialized = _do_init + + if _do_init: + # randomly initialized parameters + random_params = self.init_weights(self.key, input_shape) + params_shape_tree = jax.eval_shape(lambda params: params, random_params) + else: + init_fn = partial(self.init_weights, input_shape=input_shape) + params_shape_tree = jax.eval_shape(init_fn, self.key) + + logger.info( + "Model weights are not initialized as `_do_init` is set to `False`. " + f"Make sure to call `{self.__class__.__name__}.init_weights` manually to initialize the weights." + ) + + # get the shape of the parameters + self._params_shape_tree = params_shape_tree + + # save required_params as set + self._required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys()) + + # initialize the parameters + if _do_init: + self.params = random_params + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> dict: + raise NotImplementedError(f"init method has to be implemented for {self}") + + def enable_gradient_checkpointing(self): + raise NotImplementedError(f"gradient checkpointing method has to be implemented for {self}") + + @classmethod + def _from_config(cls, config, **kwargs): + """ + All context managers that the model should be initialized under go here. + """ + return cls(config, **kwargs) + + @property + def framework(self) -> str: + """ + :str: Identifies that this is a Flax model. + """ + return "flax" + + @property + def config(self) -> PretrainedConfig: + return self._config + + @property + def module(self) -> nn.Module: + return self._module + + @property + def params(self) -> Union[dict, FrozenDict]: + if not self._is_initialized: + raise ValueError( + "`params` cannot be accessed from model when the model is created with `_do_init=False`. " + "You must call `init_weights` manually and store the params outside of the model and " + "pass it explicitly where needed." + ) + return self._params + + @property + def required_params(self) -> set: + return self._required_params + + @property + def params_shape_tree(self) -> dict: + return self._params_shape_tree + + @params.setter + def params(self, params: Union[dict, FrozenDict]): + # don't set params if the model is not initialized + if not self._is_initialized: + raise ValueError( + "`params` cannot be set from model when the model is created with `_do_init=False`. " + "You store the params outside of the model." + ) + + if isinstance(params, FrozenDict): + params = unfreeze(params) + param_keys = set(flatten_dict(params).keys()) + if len(self.required_params - param_keys) > 0: + raise ValueError( + "Some parameters are missing. Make sure that `params` include the following " + f"parameters {self.required_params - param_keys}" + ) + self._params = params + + def _cast_floating_to(self, params: Union[dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any: + """ + Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`. + """ + + # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27 + def conditional_cast(param): + if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating): + param = param.astype(dtype) + return param + + if mask is None: + return jax.tree_util.tree_map(conditional_cast, params) + + flat_params = flatten_dict(params) + flat_mask, _ = jax.tree_util.tree_flatten(mask) + + for masked, key in zip(flat_mask, sorted(flat_params.keys())): + if masked: + flat_params[key] = conditional_cast(flat_params[key]) + + return unflatten_dict(flat_params) + + def to_bf16(self, params: Union[dict, FrozenDict], mask: Any = None): + r""" + Cast the floating-point `params` to `jax.numpy.bfloat16`. This returns a new `params` tree and does not cast + the `params` in place. + + This method can be used on TPU to explicitly convert the model parameters to bfloat16 precision to do full + half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed. + + Arguments: + params (`Union[Dict, FrozenDict]`): + A `PyTree` of model parameters. + mask (`Union[Dict, FrozenDict]`): + A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params + you want to cast, and should be `False` for those you want to skip. + + Examples: + + ```python + >>> from transformers import FlaxBertModel + + >>> # load model + >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased") + >>> # By default, the model parameters will be in fp32 precision, to cast these to bfloat16 precision + >>> model.params = model.to_bf16(model.params) + >>> # If you want don't want to cast certain parameters (for example layer norm bias and scale) + >>> # then pass the mask as follows + >>> from flax import traverse_util + + >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased") + >>> flat_params = traverse_util.flatten_dict(model.params) + >>> mask = { + ... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale")) + ... for path in flat_params + ... } + >>> mask = traverse_util.unflatten_dict(mask) + >>> model.params = model.to_bf16(model.params, mask) + ```""" + return self._cast_floating_to(params, jnp.bfloat16, mask) + + def to_fp32(self, params: Union[dict, FrozenDict], mask: Any = None): + r""" + Cast the floating-point `params` to `jax.numpy.float32`. This method can be used to explicitly convert the + model parameters to fp32 precision. This returns a new `params` tree and does not cast the `params` in place. + + Arguments: + params (`Union[Dict, FrozenDict]`): + A `PyTree` of model parameters. + mask (`Union[Dict, FrozenDict]`): + A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params + you want to cast, and should be `False` for those you want to skip + + Examples: + + ```python + >>> from transformers import FlaxBertModel + + >>> # Download model and configuration from huggingface.co + >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased") + >>> # By default, the model params will be in fp32, to illustrate the use of this method, + >>> # we'll first cast to fp16 and back to fp32 + >>> model.params = model.to_f16(model.params) + >>> # now cast back to fp32 + >>> model.params = model.to_fp32(model.params) + ```""" + return self._cast_floating_to(params, jnp.float32, mask) + + def to_fp16(self, params: Union[dict, FrozenDict], mask: Any = None): + r""" + Cast the floating-point `params` to `jax.numpy.float16`. This returns a new `params` tree and does not cast the + `params` in place. + + This method can be used on GPU to explicitly convert the model parameters to float16 precision to do full + half-precision training or to save weights in float16 for inference in order to save memory and improve speed. + + Arguments: + params (`Union[Dict, FrozenDict]`): + A `PyTree` of model parameters. + mask (`Union[Dict, FrozenDict]`): + A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params + you want to cast, and should be `False` for those you want to skip + + Examples: + + ```python + >>> from transformers import FlaxBertModel + + >>> # load model + >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased") + >>> # By default, the model params will be in fp32, to cast these to float16 + >>> model.params = model.to_fp16(model.params) + >>> # If you want don't want to cast certain parameters (for example layer norm bias and scale) + >>> # then pass the mask as follows + >>> from flax import traverse_util + + >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased") + >>> flat_params = traverse_util.flatten_dict(model.params) + >>> mask = { + ... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale")) + ... for path in flat_params + ... } + >>> mask = traverse_util.unflatten_dict(mask) + >>> model.params = model.to_fp16(model.params, mask) + ```""" + return self._cast_floating_to(params, jnp.float16, mask) + + @classmethod + def load_flax_weights(cls, resolved_archive_file): + try: + if resolved_archive_file.endswith(".safetensors"): + state = safe_load_file(resolved_archive_file) + state = unflatten_dict(state, sep=".") + else: + with open(resolved_archive_file, "rb") as state_f: + state = from_bytes(cls, state_f.read()) + except (UnpicklingError, msgpack.exceptions.ExtraData) as e: + try: + with open(resolved_archive_file) as f: + if f.read().startswith("version"): + raise OSError( + "You seem to have cloned a repository without having git-lfs installed. Please" + " install git-lfs and run `git lfs install` followed by `git lfs pull` in the" + " folder you cloned." + ) + else: + raise ValueError from e + except (UnicodeDecodeError, ValueError): + raise OSError(f"Unable to convert {resolved_archive_file} to Flax deserializable object. ") + + return state + + @classmethod + def load_flax_sharded_weights(cls, shard_files): + """ + This is the same as [`flax.serialization.from_bytes`] + (https:lax.readthedocs.io/en/latest/_modules/flax/serialization.html#from_bytes) but for a sharded checkpoint. + + This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being + loaded in the model. + + Args: + shard_files (`list[str]`: + The list of shard files to load. + + Returns: + `Dict`: A nested dictionary of the model parameters, in the expected format for flax models : `{'model': + {'params': {'...'}}}`. + """ + + # Load the index + state_sharded_dict = {} + + for shard_file in shard_files: + # load using msgpack utils + try: + with open(shard_file, "rb") as state_f: + state = from_bytes(cls, state_f.read()) + except (UnpicklingError, msgpack.exceptions.ExtraData) as e: + with open(shard_file) as f: + if f.read().startswith("version"): + raise OSError( + "You seem to have cloned a repository without having git-lfs installed. Please" + " install git-lfs and run `git lfs install` followed by `git lfs pull` in the" + " folder you cloned." + ) + else: + raise ValueError from e + except (UnicodeDecodeError, ValueError): + raise OSError(f"Unable to convert {shard_file} to Flax deserializable object. ") + + state = flatten_dict(state, sep="/") + state_sharded_dict.update(state) + del state + gc.collect() + + # the state dict is unflattened to the match the format of model.params + return unflatten_dict(state_sharded_dict, sep="/") + + @classmethod + def can_generate(cls) -> bool: + """ + Returns whether this model can generate sequences with `.generate()`. Returns: + `bool`: Whether this model can generate sequences with `.generate()`. + """ + # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation. + # Alternatively, the model can also have a custom `generate` function. + if "GenerationMixin" in str(cls.prepare_inputs_for_generation) and "GenerationMixin" in str(cls.generate): + return False + return True + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Union[str, os.PathLike], + dtype: jnp.dtype = jnp.float32, + *model_args, + config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None, + cache_dir: Optional[Union[str, os.PathLike]] = None, + ignore_mismatched_sizes: bool = False, + force_download: bool = False, + local_files_only: bool = False, + token: Optional[Union[str, bool]] = None, + revision: str = "main", + **kwargs, + ): + r""" + Instantiate a pretrained flax model from a pre-trained model configuration. + + The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come + pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning + task. + + The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those + weights are discarded. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`): + Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + - A path to a *directory* containing model weights saved using + [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. + - A path or url to a *pt index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In this case, + `from_pt` should be set to `True`. + dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): + The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and + `jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given `dtype`. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and + [`~FlaxPreTrainedModel.to_bf16`]. + model_args (sequence of positional arguments, *optional*): + All remaining positional arguments will be passed to the underlying model's `__init__` method. + config (`Union[PretrainedConfig, str, os.PathLike]`, *optional*): + Can be either: + + - an instance of a class derived from [`PretrainedConfig`], + - a string or path valid as input to [`~PretrainedConfig.from_pretrained`]. + + Configuration for the model to use instead of an automatically loaded configuration. Configuration can + be automatically loaded when: + + - The model is a model provided by the library (loaded with the *model id* string of a pretrained + model). + - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the + save directory. + - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a + configuration JSON file named *config.json* is found in the directory. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + from_pt (`bool`, *optional*, defaults to `False`): + Load the model weights from a PyTorch checkpoint save file (see docstring of + `pretrained_model_name_or_path` argument). + ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`): + Whether or not to raise an error if some of the weights from the checkpoint do not have the same size + as the weights of the model (if for instance, you are instantiating a model with 10 labels from a + checkpoint with 3 labels). + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download: + Deprecated and ignored. All downloads are now resumed by default when possible. + Will be removed in v5 of Transformers. + proxies (`dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + token (`str` or `bool`, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use + the token generated when running `hf auth login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + + + + + To test a pull request you made on the Hub, you can pass `revision="refs/pr/"`. + + + + subfolder (`str`, *optional*, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can + specify the folder name here. + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., + `output_attentions=True`). Behaves differently depending on whether a `config` is provided or + automatically loaded: + + - If a configuration is provided with `config`, `**kwargs` will be directly passed to the + underlying model's `__init__` method (we assume all relevant updates to the configuration have + already been done) + - If a configuration is not provided, `kwargs` will be first passed to the configuration class + initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that + corresponds to a configuration attribute will be used to override said attribute with the + supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute + will be passed to the underlying model's `__init__` function. + + Examples: + + ```python + >>> from transformers import BertConfig, FlaxBertModel + + >>> # Download model and configuration from huggingface.co and cache. + >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased") + >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable). + >>> model = FlaxBertModel.from_pretrained("./test/saved_model/") + >>> # Loading from a PyTorch checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable). + >>> config = BertConfig.from_json_file("./pt_model/config.json") + >>> model = FlaxBertModel.from_pretrained("./pt_model/pytorch_model.bin", from_pt=True, config=config) + ```""" + from_pt = kwargs.pop("from_pt", False) + resume_download = kwargs.pop("resume_download", None) + proxies = kwargs.pop("proxies", None) + use_auth_token = kwargs.pop("use_auth_token", None) + trust_remote_code = kwargs.pop("trust_remote_code", None) + from_pipeline = kwargs.pop("_from_pipeline", None) + from_auto_class = kwargs.pop("_from_auto", False) + _do_init = kwargs.pop("_do_init", True) + subfolder = kwargs.pop("subfolder", "") + commit_hash = kwargs.pop("_commit_hash", None) + + # Not relevant for Flax Models + _ = kwargs.pop("adapter_kwargs", None) + + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if token is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + token = use_auth_token + + if trust_remote_code is True: + logger.warning( + "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is" + " ignored." + ) + + user_agent = {"file_type": "model", "framework": "flax", "from_auto_class": from_auto_class} + if from_pipeline is not None: + user_agent["using_pipeline"] = from_pipeline + + if is_offline_mode() and not local_files_only: + logger.info("Offline mode: forcing local_files_only=True") + local_files_only = True + + # Load config if we don't provide a configuration + if not isinstance(config, PretrainedConfig): + config_path = config if config is not None else pretrained_model_name_or_path + config, model_kwargs = cls.config_class.from_pretrained( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + _from_auto=from_auto_class, + _from_pipeline=from_pipeline, + _commit_hash=commit_hash, + **kwargs, + ) + else: + model_kwargs = kwargs.copy() + + if commit_hash is None: + commit_hash = getattr(config, "_commit_hash", None) + + # Add the dtype to model_kwargs + model_kwargs["dtype"] = dtype + + # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the + # index of the files. + is_sharded = False + + # Load model + if pretrained_model_name_or_path is not None: + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + is_local = os.path.isdir(pretrained_model_name_or_path) + if is_local: + if os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)): + # Load from a Flax checkpoint + archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME) + elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_INDEX_NAME)): + # Load from a sharded Flax checkpoint + archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_INDEX_NAME) + is_sharded = True + elif is_safetensors_available() and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, SAFE_WEIGHTS_NAME) + ): + # Load from a safetensors checkpoint + archive_file = os.path.join(pretrained_model_name_or_path, subfolder, SAFE_WEIGHTS_NAME) + elif is_safetensors_available() and os.path.isfile( + os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME) + ): + # Load from a safetensors checkpoint + archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME) + elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)): + # Load from a PyTorch checkpoint + archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME) + elif from_pt and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_INDEX_NAME) + ): + # Load from a sharded pytorch checkpoint + archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_INDEX_NAME) + is_sharded = True + # At this stage we don't have a weight file so we will raise an error. + elif is_safetensors_available() and os.path.isfile( + os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME) + ): + # Load from a sharded safetensors checkpoint + archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME) + is_sharded = True + raise NotImplementedError("Support for sharded checkpoints using safetensors is coming soon!") + elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)): + raise OSError( + f"Error no file named {FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} " + "but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those " + "weights." + ) + else: + raise OSError( + f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory " + f"{pretrained_model_name_or_path}." + ) + elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)): + archive_file = pretrained_model_name_or_path + is_local = True + elif is_remote_url(pretrained_model_name_or_path): + filename = pretrained_model_name_or_path + resolved_archive_file = download_url(pretrained_model_name_or_path) + else: + if from_pt: + filename = WEIGHTS_NAME + else: + filename = FLAX_WEIGHTS_NAME + + try: + # Load from URL or cache if already cached + cached_file_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, + "proxies": proxies, + "resume_download": resume_download, + "local_files_only": local_files_only, + "token": token, + "user_agent": user_agent, + "revision": revision, + "subfolder": subfolder, + "_raise_exceptions_for_gated_repo": False, + "_raise_exceptions_for_missing_entries": False, + "_commit_hash": commit_hash, + } + resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs) + + # Maybe the checkpoint is sharded, we try to grab the index name in this case. + if resolved_archive_file is None and filename == FLAX_WEIGHTS_NAME: + resolved_archive_file = cached_file( + pretrained_model_name_or_path, FLAX_WEIGHTS_INDEX_NAME, **cached_file_kwargs + ) + if resolved_archive_file is not None: + is_sharded = True + + # Maybe the checkpoint is pytorch sharded, we try to grab the pytorch index name in this case. + if resolved_archive_file is None and from_pt: + resolved_archive_file = cached_file( + pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **cached_file_kwargs + ) + if resolved_archive_file is not None: + is_sharded = True + + # If we still haven't found anything, look for `safetensors`. + if resolved_archive_file is None: + # No support for sharded safetensors yet, so we'll raise an error if that's all we find. + filename = SAFE_WEIGHTS_NAME + resolved_archive_file = cached_file( + pretrained_model_name_or_path, SAFE_WEIGHTS_NAME, **cached_file_kwargs + ) + + # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None + # result when internet is up, the repo and revision exist, but the file does not. + if resolved_archive_file is None: + # Otherwise, maybe there is a TF or Torch model file. We try those to give a helpful error + # message. + has_file_kwargs = { + "revision": revision, + "proxies": proxies, + "token": token, + "cache_dir": cache_dir, + "local_files_only": local_files_only, + } + if has_file(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME, **has_file_kwargs): + is_sharded = True + raise NotImplementedError( + "Support for sharded checkpoints using safetensors is coming soon!" + ) + elif has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs): + raise OSError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {FLAX_WEIGHTS_NAME} but there is a file for PyTorch weights. Use `from_pt=True` to" + " load this model from those weights." + ) + elif has_file(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **has_file_kwargs): + raise OSError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {FLAX_WEIGHTS_INDEX_NAME} but there is a sharded file for PyTorch weights. Use" + " `from_pt=True` to load this model from those weights." + ) + else: + raise OSError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}." + ) + except OSError: + # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted + # to the original exception. + raise + except Exception: + # For any other exception, we throw a generic error. + raise OSError( + f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it" + " from 'https://huggingface.co/models', make sure you don't have a local directory with the" + f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a" + f" directory containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}." + ) + + if is_local: + logger.info(f"loading weights file {archive_file}") + resolved_archive_file = archive_file + filename = resolved_archive_file.split(os.path.sep)[-1] + else: + logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}") + else: + resolved_archive_file = None + + # We'll need to download and cache each checkpoint shard if the checkpoint is sharded. + if is_sharded: + # resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case. + resolved_archive_file, _ = get_checkpoint_shard_files( + pretrained_model_name_or_path, + resolved_archive_file, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + revision=revision, + subfolder=subfolder, + _commit_hash=commit_hash, + ) + + safetensors_from_pt = False + if filename == SAFE_WEIGHTS_NAME: + with safe_open(resolved_archive_file, framework="flax") as f: + safetensors_metadata = f.metadata() + if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax"]: + raise OSError( + f"The safetensors archive passed at {resolved_archive_file} does not contain the valid metadata." + " Make sure you save your model with the `save_pretrained` method." + ) + safetensors_from_pt = safetensors_metadata.get("format") == "pt" + + # init random models + model = cls(config, *model_args, _do_init=_do_init, **model_kwargs) + + if from_pt or safetensors_from_pt: + state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file, is_sharded) + else: + if is_sharded: + state = cls.load_flax_sharded_weights(resolved_archive_file) + else: + state = cls.load_flax_weights(resolved_archive_file) + # make sure all arrays are stored as jnp.arrays + # NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4: + # https://github.com/google/flax/issues/1261 + if _do_init: + state = jax.tree_util.tree_map(jnp.array, state) + else: + # keep the params on CPU if we don't want to initialize + state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.local_devices(backend="cpu")[0]), state) + + if "batch_stats" in state: # if flax model contains batch norm layers + # if model is base model only use model_prefix key + if ( + cls.base_model_prefix not in dict(model.params_shape_tree["params"]) + and cls.base_model_prefix in state["params"] + ): + state["params"] = state["params"][cls.base_model_prefix] + state["batch_stats"] = state["batch_stats"][cls.base_model_prefix] + + # if model is head model and we are loading weights from base model + # we initialize new params dict with base_model_prefix + if ( + cls.base_model_prefix in dict(model.params_shape_tree["params"]) + and cls.base_model_prefix not in state["params"] + ): + state = { + "params": {cls.base_model_prefix: state["params"]}, + "batch_stats": {cls.base_model_prefix: state["batch_stats"]}, + } + + else: + # if model is base model only use model_prefix key + if cls.base_model_prefix not in dict(model.params_shape_tree) and cls.base_model_prefix in state: + state = state[cls.base_model_prefix] + + # if model is head model and we are loading weights from base model + # we initialize new params dict with base_model_prefix + if cls.base_model_prefix in dict(model.params_shape_tree) and cls.base_model_prefix not in state: + state = {cls.base_model_prefix: state} + + # flatten dicts + state = flatten_dict(state) + + random_state = flatten_dict(unfreeze(model.params if _do_init else model.params_shape_tree)) + + missing_keys = model.required_params - set(state.keys()) + unexpected_keys = set(state.keys()) - model.required_params + + # Disabling warning when porting pytorch weights to flax, flax does not uses num_batches_tracked + for unexpected_key in unexpected_keys.copy(): + if "num_batches_tracked" in unexpected_key[-1]: + unexpected_keys.remove(unexpected_key) + + if missing_keys and not _do_init: + logger.warning( + f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. " + "Make sure to call model.init_weights to initialize the missing weights." + ) + cls._missing_keys = missing_keys + + # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not + # matching the weights in the model. + mismatched_keys = [] + for key in state: + if key in random_state and state[key].shape != random_state[key].shape: + if ignore_mismatched_sizes: + mismatched_keys.append((key, state[key].shape, random_state[key].shape)) + state[key] = random_state[key] + else: + raise ValueError( + f"Trying to load the pretrained weight for {key} failed: checkpoint has shape " + f"{state[key].shape} which is incompatible with the model shape {random_state[key].shape}. " + "Using `ignore_mismatched_sizes=True` if you really want to load this checkpoint inside this " + "model." + ) + + # add missing keys as random parameters if we are initializing + if missing_keys and _do_init: + for missing_key in missing_keys: + state[missing_key] = random_state[missing_key] + + # remove unexpected keys to not be saved again + for unexpected_key in unexpected_keys: + del state[unexpected_key] + + if len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" + f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" + f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or" + " with another architecture (e.g. initializing a BertForSequenceClassification model from a" + " BertForPreTraining model).\n- This IS NOT expected if you are initializing" + f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical" + " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)." + ) + else: + logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") + + if len(missing_keys) > 0: + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" + " TRAIN this model on a down-stream task to be able to use it for predictions and inference." + ) + elif len(mismatched_keys) == 0: + logger.info( + f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" + f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint" + f" was trained on, you can already use {model.__class__.__name__} for predictions without further" + " training." + ) + if len(mismatched_keys) > 0: + mismatched_warning = "\n".join( + [ + f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" + for key, shape1, shape2 in mismatched_keys + ] + ) + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not" + f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able" + " to use it for predictions and inference." + ) + + # dictionary of key: dtypes for the model params + param_dtypes = jax.tree_util.tree_map(lambda x: x.dtype, state) + # extract keys of parameters not in jnp.float32 + fp16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.float16] + bf16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.bfloat16] + + # raise a warning if any of the parameters are not in jnp.float32 + if len(fp16_params) > 0: + logger.warning( + f"Some of the weights of {model.__class__.__name__} were initialized in float16 precision from " + f"the model checkpoint at {pretrained_model_name_or_path}:\n{fp16_params}\n" + "You should probably UPCAST the model weights to float32 if this was not intended. " + "See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this." + ) + + if len(bf16_params) > 0: + logger.warning( + f"Some of the weights of {model.__class__.__name__} were initialized in bfloat16 precision from " + f"the model checkpoint at {pretrained_model_name_or_path}:\n{bf16_params}\n" + "You should probably UPCAST the model weights to float32 if this was not intended. " + "See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this." + ) + + # If it is a model with generation capabilities, attempt to load the generation config + if model.can_generate(): + try: + model.generation_config = GenerationConfig.from_pretrained( + pretrained_model_name_or_path, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + _from_auto=from_auto_class, + _from_pipeline=from_pipeline, + **kwargs, + ) + except OSError: + logger.info( + "Generation config file not found, using a generation config created from the model config." + ) + pass + + if _do_init: + # set correct parameters + model.params = unflatten_dict(state) + return model + else: + return model, unflatten_dict(state) + + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + params=None, + push_to_hub=False, + max_shard_size="10GB", + token: Optional[Union[str, bool]] = None, + safe_serialization: bool = False, + **kwargs, + ): + """ + Save a model and its configuration file to a directory, so that it can be re-loaded using the + `[`~FlaxPreTrainedModel.from_pretrained`]` class method + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to which to save. Will be created if it doesn't exist. + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the + repository you want to push to with `repo_id` (will default to the name of `save_directory` in your + namespace). + max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`): + The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size + lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`). + + + + If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard + which will be bigger than `max_shard_size`. + + + + token (`str` or `bool`, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use + the token generated when running `hf auth login` (stored in `~/.huggingface`). + kwargs (`dict[str, Any]`, *optional*): + Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. + safe_serialization (`bool`, *optional*, defaults to `False`): + Whether to save the model using `safetensors` or through msgpack. + """ + use_auth_token = kwargs.pop("use_auth_token", None) + + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if token is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + token = use_auth_token + + if token is not None: + kwargs["token"] = token + + if os.path.isfile(save_directory): + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") + return + + os.makedirs(save_directory, exist_ok=True) + + if push_to_hub: + commit_message = kwargs.pop("commit_message", None) + repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) + repo_id = self._create_repo(repo_id, **kwargs) + files_timestamps = self._get_files_timestamps(save_directory) + + # get abs dir + save_directory = os.path.abspath(save_directory) + # save config as well + self.config.architectures = [self.__class__.__name__[4:]] + + # If we have a custom model, we copy the file defining it in the folder and set the attributes so it can be + # loaded from the Hub. + if self._auto_class is not None: + custom_object_save(self, save_directory, config=self.config) + + self.config.save_pretrained(save_directory) + if self.can_generate(): + self.generation_config.save_pretrained(save_directory) + + # save model + weights_name = SAFE_WEIGHTS_NAME if safe_serialization else FLAX_WEIGHTS_NAME + output_model_file = os.path.join(save_directory, weights_name) + + shards, index = flax_shard_checkpoint(params if params is not None else self.params, max_shard_size) + # Clean the folder from a previous save + for filename in os.listdir(save_directory): + full_filename = os.path.join(save_directory, filename) + weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "") + if filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) and filename not in shards: + os.remove(full_filename) + + if index is None: + if safe_serialization: + params = params if params is not None else self.params + flat_dict = flatten_dict(params, sep=".") + safe_save_file(flat_dict, output_model_file, metadata={"format": "flax"}) + else: + with open(output_model_file, "wb") as f: + params = params if params is not None else self.params + model_bytes = to_bytes(params) + f.write(model_bytes) + + else: + save_index_file = os.path.join(save_directory, FLAX_WEIGHTS_INDEX_NAME) + # Save the index as well + with open(save_index_file, "w", encoding="utf-8") as f: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + f.write(content) + logger.info( + f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be " + f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) + for shard_file, shard in shards.items(): + # the shard item are unflattened, to save them we need to flatten them again + with open(os.path.join(save_directory, shard_file), mode="wb") as f: + params = unflatten_dict(shard, sep="/") + shard_bytes = to_bytes(params) + f.write(shard_bytes) + + logger.info(f"Model weights saved in {output_model_file}") + + if push_to_hub: + self._upload_modified_files( + save_directory, + repo_id, + files_timestamps, + commit_message=commit_message, + token=token, + ) + + @classmethod + def register_for_auto_class(cls, auto_class="FlaxAutoModel"): + """ + Register this class with a given auto class. This should only be used for custom models as the ones in the + library are already mapped with an auto class. + + + + Args: + auto_class (`str` or `type`, *optional*, defaults to `"FlaxAutoModel"`): + The auto class to register this new model with. + """ + if not isinstance(auto_class, str): + auto_class = auto_class.__name__ + + import transformers.models.auto as auto_module + + if not hasattr(auto_module, auto_class): + raise ValueError(f"{auto_class} is not a valid auto class.") + + cls._auto_class = auto_class + + +# To update the docstring, we need to copy the method, otherwise we change the original docstring. +FlaxPreTrainedModel.push_to_hub = copy_func(FlaxPreTrainedModel.push_to_hub) +if FlaxPreTrainedModel.push_to_hub.__doc__ is not None: + FlaxPreTrainedModel.push_to_hub.__doc__ = FlaxPreTrainedModel.push_to_hub.__doc__.format( + object="model", object_class="FlaxAutoModel", object_files="model checkpoint" + ) + + +def overwrite_call_docstring(model_class, docstring): + # copy __call__ function to be sure docstring is changed only for this function + model_class.__call__ = copy_func(model_class.__call__) + # delete existing docstring + model_class.__call__.__doc__ = None + # set correct docstring + model_class.__call__ = add_start_docstrings_to_model_forward(docstring)(model_class.__call__) + + +def append_call_sample_docstring( + model_class, checkpoint, output_type, config_class, mask=None, revision=None, real_checkpoint=None +): + model_class.__call__ = copy_func(model_class.__call__) + model_class.__call__ = add_code_sample_docstrings( + checkpoint=checkpoint, + output_type=output_type, + config_class=config_class, + model_cls=model_class.__name__, + revision=revision, + real_checkpoint=real_checkpoint, + )(model_class.__call__) + + +def append_replace_return_docstrings(model_class, output_type, config_class): + model_class.__call__ = copy_func(model_class.__call__) + model_class.__call__ = replace_return_docstrings( + output_type=output_type, + config_class=config_class, + )(model_class.__call__) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_gguf_pytorch_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_gguf_pytorch_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..08aaac3617ff65b5f9e8306bb04ea454e20be5ee --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_gguf_pytorch_utils.py @@ -0,0 +1,532 @@ +# Copyright 2024 The ggml.ai team and The HuggingFace Inc. team. and pygguf author (github.com/99991) +# https://github.com/99991/pygguf +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from typing import NamedTuple, Optional + +import numpy as np +from tqdm.auto import tqdm + +from .integrations import ( + GGUF_CONFIG_MAPPING, + GGUF_TOKENIZER_MAPPING, + _gguf_parse_value, +) +from .utils import is_torch_available +from .utils.import_utils import is_gguf_available +from .utils.logging import get_logger + + +if is_torch_available(): + import torch + +logger = get_logger(__name__) + + +GGUF_TO_TRANSFORMERS_MAPPING = { + "ignore": { + "GGUF": { + "version": "version", + "tensor_count": "tensor_count", + "kv_count": "kv_count", + }, + "general": {"file_type": "file_type", "quantization_version": "quantization_version"}, + }, + "config": GGUF_CONFIG_MAPPING, + "tokenizer": {"tokenizer": GGUF_TOKENIZER_MAPPING["tokenizer"]}, + "tokenizer_config": {"tokenizer": GGUF_TOKENIZER_MAPPING["tokenizer_config"]}, +} + +GGUF_SUPPORTED_ARCHITECTURES = list(GGUF_TO_TRANSFORMERS_MAPPING["config"].keys()) + + +class GGUFTensor(NamedTuple): + weights: np.ndarray + name: str + metadata: dict + + +class TensorProcessor: + def __init__(self, config=None): + self.config = config or {} + + def process(self, weights, name, **kwargs): + return GGUFTensor(weights, name, {}) + + +class LlamaTensorProcessor(TensorProcessor): + def __init__(self, config=None): + super().__init__(config=config) + + def process(self, weights, name, **kwargs): + if ".attn_k." in name or ".attn_q." in name: + num_heads = self.config.get("num_attention_heads") + num_kv_heads = self.config.get("num_key_value_heads") + + if None in (num_heads, num_kv_heads): + return GGUFTensor(weights, name, {}) + if ".attn_q." in name: + weights = self._reverse_permute_weights(weights, num_heads, num_heads) + elif ".attn_k." in name: + weights = self._reverse_permute_weights(weights, num_heads, num_kv_heads) + return GGUFTensor(weights, name, {}) + + def _reverse_permute_weights( + self, weights: np.ndarray, n_head: int, num_kv_heads: Optional[int] = None + ) -> np.ndarray: + # Original permutation implementation + # https://github.com/ggerganov/llama.cpp/blob/a38b884c6c4b0c256583acfaaabdf556c62fabea/convert_hf_to_gguf.py#L1402-L1408 + if num_kv_heads is not None and n_head != num_kv_heads: + n_head = num_kv_heads + + dim = weights.shape[0] // n_head // 2 + w = weights.reshape(n_head, dim, 2, *weights.shape[1:]) + return w.swapaxes(2, 1).reshape(weights.shape) + + +class Qwen2MoeTensorProcessor(TensorProcessor): + def __init__(self, config=None): + super().__init__(config=config) + + def process(self, weights, name, **kwargs): + if "_exp" in name: + tensor_key_mapping = kwargs.get("tensor_key_mapping") + parsed_parameters = kwargs.get("parsed_parameters") + if tensor_key_mapping: + self._split_moe_expert_tensor(weights, parsed_parameters, name, tensor_key_mapping) + return GGUFTensor(weights, None, {}) + if "ffn_gate_inp_shexp" in name: + # for compatibility tensor shared_expert_gate must be (1, 2048) dim, + # quantized one is (2048) + weights = np.expand_dims(weights, axis=0) + return GGUFTensor(weights, name, {}) + + def _split_moe_expert_tensor( + self, weights: np.ndarray, parsed_parameters: dict[str, dict], name: str, tensor_key_mapping: dict + ): + # Original merge implementation + # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L1994-L2022 + name = tensor_key_mapping[name] + w_counter = self.config.get("num_experts", 60) + for i in range(0, w_counter): + temp_name = name.replace("mlp.experts.", f"mlp.experts.{i}.") + exp_weight = weights[i] + parsed_parameters["tensors"][temp_name] = torch.from_numpy(np.copy(exp_weight)) + + +class BloomTensorProcessor(TensorProcessor): + def __init__(self, config=None): + super().__init__(config=config) + + def process(self, weights, name, **kwargs): + if "attn_qkv" in name: + num_heads = self.config["n_head"] + n_embed = self.config["hidden_size"] + if "weight" in name: + weights = self._reverse_reshape_weights(weights, num_heads, n_embed) + else: + weights = self._reverse_reshape_bias(weights, num_heads, n_embed) + return GGUFTensor(weights, name, {}) + + def _reverse_reshape_weights(self, weights: np.ndarray, n_head: int, n_embed: int): + # Original reshape implementation + # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L972-L985 + q, k, v = np.array_split(weights, 3, axis=0) + + q = q.reshape(n_head, n_embed // n_head, n_embed) + k = k.reshape(n_head, n_embed // n_head, n_embed) + v = v.reshape(n_head, n_embed // n_head, n_embed) + qkv_weights = np.stack([q, k, v], axis=1) + + return qkv_weights.reshape(n_head * 3 * (n_embed // n_head), n_embed) + + def _reverse_reshape_bias(self, weights: np.ndarray, n_head: int, n_embed: int): + # Original reshape implementation + # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L986-L998 + q_bias, k_bias, v_bias = np.array_split(weights, 3) + + q_bias = q_bias.reshape(n_head, n_embed // n_head) + k_bias = k_bias.reshape(n_head, n_embed // n_head) + v_bias = v_bias.reshape(n_head, n_embed // n_head) + + qkv_bias = np.stack([q_bias, k_bias, v_bias], axis=1).flatten() + return qkv_bias + + +class T5TensorProcessor(TensorProcessor): + def __init__(self, config=None): + super().__init__(config=config) + + def process(self, weights, name, **kwargs): + bid = None + for chunk in name.split("."): + if chunk.isdigit(): + bid = int(chunk) + break + return GGUFTensor(weights, name, {"bid": bid}) + + +class GPT2TensorProcessor(TensorProcessor): + def __init__(self, config=None): + super().__init__(config=config) + + def process(self, weights, name, **kwargs): + # Original transpose implementation + # https://github.com/ggerganov/llama.cpp/blob/a38b884c6c4b0c256583acfaaabdf556c62fabea/convert_hf_to_gguf.py#L2060-L2061 + if ( + "attn_qkv.weight" in name + or "ffn_down.weight" in name + or "ffn_up.weight" in name + or "attn_output.weight" in name + ): + weights = weights.T + + # Handle special case for output.weight + if name == "output.weight": + # output.weight has conflicts with attn_output.weight in name checking + # Store the tensor directly and signal to skip further processing + name = "lm_head.weight" + parsed_parameters = kwargs.get("parsed_parameters", {}) + parsed_parameters["tensors"][name] = torch.from_numpy(np.copy(weights)) + name = None # Signal to skip further processing + return GGUFTensor(weights, name, {}) + + +class MambaTensorProcessor(TensorProcessor): + def __init__(self, config=None): + super().__init__(config=config) + + def process(self, weights, name, **kwargs): + if "ssm_conv1d.weight" in name: + # for compatibility tensor ssm_conv1d must be (5120, 1, 4]) dim, + # quantized one is (5120, 4) + weights = np.expand_dims(weights, axis=1) + if "ssm_a" in name: + # Original exponential implementation + # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L2975-L2977 + weights = np.log(-weights) + return GGUFTensor(weights, name, {}) + + +class NemotronTensorProcessor(TensorProcessor): + def __init__(self, config=None): + super().__init__(config=config) + + # ref : https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L4666 + def process(self, weights, name, **kwargs): + if "norm.weight" in name: + weights = weights - 1 + return GGUFTensor(weights, name, {}) + + +class Gemma2TensorProcessor(TensorProcessor): + def __init__(self, config=None): + super().__init__(config=config) + + # ref: https://github.com/ggerganov/llama.cpp/blob/d79d8f39b4da6deca4aea8bf130c6034c482b320/convert_hf_to_gguf.py#L3191 + # ref: https://github.com/huggingface/transformers/blob/fc37f38915372c15992b540dfcbbe00a916d4fc6/src/transformers/models/gemma/modeling_gemma.py#L89 + def process(self, weights, name, **kwargs): + if "norm.weight" in name: + weights = weights - 1 + return GGUFTensor(weights, name, {}) + + +class Lfm2TensorProcessor(TensorProcessor): + def __init__(self, config=None): + super().__init__(config=config) + + def process(self, weights, name, **kwargs): + if "shortconv.conv.weight" in name: + ## GGUF shape is [hidden_dim, L_cache], HF expects [hidden_dim, 1, L_cache] + weights = np.expand_dims(weights, axis=1) ## equivalent to unsqueeze(1) + return GGUFTensor(weights, name, {}) + + +TENSOR_PROCESSORS = { + "llama": LlamaTensorProcessor, + "qwen2moe": Qwen2MoeTensorProcessor, + "qwen3moe": Qwen2MoeTensorProcessor, + "bloom": BloomTensorProcessor, + "t5": T5TensorProcessor, + "t5encoder": T5TensorProcessor, + "gpt2": GPT2TensorProcessor, + "mamba": MambaTensorProcessor, + "nemotron": NemotronTensorProcessor, + "gemma2": Gemma2TensorProcessor, + "gemma3": Gemma2TensorProcessor, + "lfm2": Lfm2TensorProcessor, +} + + +def read_field(reader, field): + if field not in reader.fields: + return [] + value = reader.fields[field] + return [_gguf_parse_value(value.parts[_data_index], value.types) for _data_index in value.data] + + +# modified from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/loader.py#L1115-L1147 +def get_gguf_hf_weights_map( + hf_model, + model_type: Optional[str] = None, + num_layers: Optional[int] = None, + qual_name: str = "", +): + """ + GGUF uses this naming convention for their tensors from HF checkpoint: + `blk.N.BB.weight` and `blk.N.BB.bias` + where N signifies the block number of a layer, and BB signifies the + attention/mlp layer components. + See "Standardized tensor names" in + https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details. + """ + if is_gguf_available() and is_torch_available(): + from gguf import MODEL_ARCH_NAMES, get_tensor_name_map + else: + logger.error( + "Loading a GGUF checkpoint in PyTorch, requires both PyTorch and GGUF>=0.10.0 to be installed. Please see " + "https://pytorch.org/ and https://github.com/ggerganov/llama.cpp/tree/master/gguf-py for installation instructions." + ) + raise ImportError("Please install torch and gguf>=0.10.0 to load a GGUF checkpoint in PyTorch.") + + model_type = hf_model.config.model_type if model_type is None else model_type + num_layers = hf_model.config.num_hidden_layers if num_layers is None else num_layers + # hack: ggufs have a different name for cohere + if model_type == "cohere": + model_type = "command-r" + elif model_type == "qwen2_moe": + model_type = "qwen2moe" + elif model_type == "qwen3_moe": + model_type = "qwen3moe" + elif model_type == "gemma3_text": + model_type = "gemma3" + elif model_type == "umt5": + model_type = "t5" + arch = None + for key, value in MODEL_ARCH_NAMES.items(): + if value == model_type: + arch = key + break + if arch is None: + raise NotImplementedError( + f"Unknown gguf model_type: {model_type} in gguf-py. " + "This might because you're using an outdated version of gguf-py package, " + "you can install `gguf` package from source refer to " + "https://github.com/ggerganov/llama.cpp/tree/master/gguf-py#development" + ) + name_map = get_tensor_name_map(arch, num_layers) + + # Use a dummy conversion to get the mapping, because + # hf => gguf and gguf => hf mappings are reversed + gguf_to_hf_name_map = {} + state_dict = hf_model.state_dict() + for hf_name in state_dict: + # An exception for qwen2moe/qwen3moe model, where the expert layers are packed + if model_type in ("qwen2moe", "qwen3moe") and "mlp.experts." in hf_name: + hf_name = re.sub(r"mlp.experts.\d+.", "mlp.experts.", hf_name) + + name, suffix = hf_name, "" + if hf_name.endswith(".weight") or hf_name.endswith(".bias"): + name, suffix = hf_name.rsplit(".", 1) + suffix = "." + suffix + + gguf_name = name_map.get_name(name) + if gguf_name is None: + continue + + gguf_to_hf_name_map[gguf_name + suffix] = qual_name + hf_name + + # Some model like Bloom converted from BloomModel instead of BloomForCausalLM + # Therefore, we need to check submodule as well to get a correct mapping + if named_children := hf_model.named_children(): + for name, child in named_children: + sub_map = get_gguf_hf_weights_map(child, model_type, num_layers, qual_name=f"{qual_name}{name}.") + # Ignore the keys that are already in the main map to avoid overwriting + sub_map = {k: v for k, v in sub_map.items() if k not in gguf_to_hf_name_map} + gguf_to_hf_name_map.update(sub_map) + + return gguf_to_hf_name_map + + +def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False, model_to_load=None): + """ + Load a GGUF file and return a dictionary of parsed parameters containing tensors, the parsed + tokenizer and config attributes. + + Args: + gguf_checkpoint_path (`str`): + The path the to GGUF file to load + return_tensors (`bool`, defaults to `False`): + Whether to read the tensors from the file and return them. Not doing so is faster + and only loads the metadata in memory. + """ + if is_gguf_available() and is_torch_available(): + from gguf import GGUFReader, dequantize + else: + logger.error( + "Loading a GGUF checkpoint in PyTorch, requires both PyTorch and GGUF>=0.10.0 to be installed. Please see " + "https://pytorch.org/ and https://github.com/ggerganov/llama.cpp/tree/master/gguf-py for installation instructions." + ) + raise ImportError("Please install torch and gguf>=0.10.0 to load a GGUF checkpoint in PyTorch.") + + reader = GGUFReader(gguf_checkpoint_path) + fields = reader.fields + reader_keys = list(fields.keys()) + + parsed_parameters = {k: {} for k in GGUF_TO_TRANSFORMERS_MAPPING} + + architecture = read_field(reader, "general.architecture")[0] + # NOTE: Some GGUF checkpoints may miss `general.name` field in metadata + model_name = read_field(reader, "general.name") + + updated_architecture = None + # in llama.cpp mistral models use the same architecture as llama. We need + # to add this patch to ensure things work correctly on our side. + if "llama" in architecture and "mistral" in model_name: + updated_architecture = "mistral" + # FIXME: Currently this implementation is only for flan-t5 architecture. + # It needs to be developed for supporting legacy t5. + elif "t5" in architecture or "t5encoder" in architecture: + parsed_parameters["config"]["is_gated_act"] = True + if model_name and "umt5" in model_name[0].lower(): + updated_architecture = "umt5" + if "t5encoder" in architecture: + parsed_parameters["config"]["architectures"] = ["UMT5EncoderModel"] + else: + if "t5encoder" in architecture: + parsed_parameters["config"]["architectures"] = ["T5EncoderModel"] + updated_architecture = "t5" + else: + updated_architecture = architecture + + if "qwen2moe" in architecture: + updated_architecture = "qwen2_moe" + elif "qwen3moe" in architecture: + updated_architecture = "qwen3_moe" + + # For stablelm architecture, we need to set qkv_bias and use_parallel_residual from tensors + # If `qkv_bias=True`, qkv_proj with bias will be present in the tensors + # If `use_parallel_residual=False`, ffn_norm will be present in the tensors + if "stablelm" in architecture: + attn_bias_name = {"attn_q.bias", "attn_k.bias", "attn_v.bias"} + ffn_norm_name = "ffn_norm" + qkv_bias = any(bias_name in tensor.name for tensor in reader.tensors for bias_name in attn_bias_name) + use_parallel_residual = any(ffn_norm_name in tensor.name for tensor in reader.tensors) + parsed_parameters["config"]["use_qkv_bias"] = qkv_bias + parsed_parameters["config"]["use_parallel_residual"] = not use_parallel_residual + + if architecture not in GGUF_SUPPORTED_ARCHITECTURES and updated_architecture not in GGUF_SUPPORTED_ARCHITECTURES: + raise ValueError(f"GGUF model with architecture {architecture} is not supported yet.") + + # Handle tie_word_embeddings, if lm_head.weight is not present in tensors, + # tie_word_embeddings is true otherwise false + exceptions = ["falcon", "bloom"] + parsed_parameters["config"]["tie_word_embeddings"] = ( + all("output.weight" != tensor.name for tensor in reader.tensors) or architecture in exceptions + ) + + # List all key-value pairs in a columnized format + for gguf_key, field in reader.fields.items(): + gguf_key = gguf_key.replace(architecture, updated_architecture) + split = gguf_key.split(".") + prefix = split[0] + config_key = ".".join(split[1:]) + + value = [_gguf_parse_value(field.parts[_data_index], field.types) for _data_index in field.data] + + if len(value) == 1: + value = value[0] + + if isinstance(value, str) and architecture in value: + value = value.replace(architecture, updated_architecture) + + for parameter, parameter_renames in GGUF_TO_TRANSFORMERS_MAPPING.items(): + if prefix in parameter_renames and config_key in parameter_renames[prefix]: + renamed_config_key = parameter_renames[prefix][config_key] + if renamed_config_key == -1: + continue + + if renamed_config_key is not None: + parsed_parameters[parameter][renamed_config_key] = value + + if gguf_key in reader_keys: + reader_keys.remove(gguf_key) + + if gguf_key in reader_keys: + logger.info(f"Some keys were not parsed and added into account {gguf_key} | {value}") + + # Gemma3 GGUF checkpoint only contains weights of text backbone + if parsed_parameters["config"]["model_type"] == "gemma3": + parsed_parameters["config"]["model_type"] = "gemma3_text" + + if parsed_parameters["config"]["model_type"] == "lfm2": + gguf_num_key_value_heads = parsed_parameters["config"]["num_key_value_heads"] + # LFM2 GGUF checkpoint defines num_key_value_heads as a list of integers .e.g [0, 0, 8, 0, 0, 8, 0, 0, 8, 0, 8, 0, 8, 0, 8, 0] but we need to set it to the max value for HF + parsed_parameters["config"]["num_key_value_heads"] = max(gguf_num_key_value_heads) + ## we already read the correct intermediate_size from the GGUF checkpoint so we need to set block_auto_adjust_ff_dim to False + parsed_parameters["config"]["block_auto_adjust_ff_dim"] = False + + ## llama.cpp defines the layers that are full-attention by looking at num_key_value_heads + ## we need to set the full_attn_idxs to the layers that are full-attention + parsed_parameters["config"]["full_attn_idxs"] = [ + i for i, num_kv_heads in enumerate(gguf_num_key_value_heads) if num_kv_heads > 0 + ] + + # retrieve config vocab_size from tokenizer + # Please refer to https://github.com/huggingface/transformers/issues/32526 for more details + if "vocab_size" not in parsed_parameters["config"]: + tokenizer_parameters = parsed_parameters["tokenizer"] + if "tokens" in tokenizer_parameters: + parsed_parameters["config"]["vocab_size"] = len(tokenizer_parameters["tokens"]) + else: + logger.warning( + "Can't find a way to retrieve missing config vocab_size from tokenizer parameters. " + "This will use default value from model config class and cause unexpected behavior." + ) + + if return_tensors: + parsed_parameters["tensors"] = {} + + tensor_key_mapping = get_gguf_hf_weights_map(model_to_load) + config = parsed_parameters.get("config", {}) + + ProcessorClass = TENSOR_PROCESSORS.get(architecture, TensorProcessor) + processor = ProcessorClass(config=config) + + for tensor in tqdm(reader.tensors, desc="Converting and de-quantizing GGUF tensors..."): + name = tensor.name + weights = dequantize(tensor.data, tensor.tensor_type) + + result = processor.process( + weights=weights, + name=name, + tensor_key_mapping=tensor_key_mapping, + parsed_parameters=parsed_parameters, + ) + + weights = result.weights + name = result.name + + if name not in tensor_key_mapping: + continue + + name = tensor_key_mapping[name] + + parsed_parameters["tensors"][name] = torch.from_numpy(np.copy(weights)) + + if len(reader_keys) > 0: + logger.info(f"Some keys of the GGUF file were not considered: {reader_keys}") + + return parsed_parameters diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_outputs.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_outputs.py new file mode 100644 index 0000000000000000000000000000000000000000..1747f6fa477b56bc349f2c535705382ebc84af3c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_outputs.py @@ -0,0 +1,1715 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from dataclasses import dataclass +from typing import Optional + +import torch + +from .cache_utils import Cache, EncoderDecoderCache +from .utils import ModelOutput + + +@dataclass +class BaseModelOutput(ModelOutput): + """ + Base class for model's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class BaseModelOutputWithNoAttention(ModelOutput): + """ + Base class for model's outputs, with potential hidden states. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, num_channels, height, width)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class BaseModelOutputWithPooling(ModelOutput): + """ + Base class for model's outputs that also contains a pooling of the last hidden states. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): + Last layer hidden-state of the first token of the sequence (classification token) after further processing + through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns + the classification token after processing through a linear layer and a tanh activation function. The linear + layer weights are trained from the next sentence prediction (classification) objective during pretraining. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + pooler_output: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class BaseModelOutputWithPoolingAndNoAttention(ModelOutput): + """ + Base class for model's outputs that also contains a pooling of the last hidden states. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): + Last layer hidden-state after a pooling operation on the spatial dimensions. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, num_channels, height, width)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + pooler_output: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class BaseModelOutputWithPast(ModelOutput): + """ + Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding). + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` + input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class BaseModelOutputWithCrossAttentions(ModelOutput): + """ + Base class for model's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class BaseModelOutputWithPoolingAndCrossAttentions(ModelOutput): + """ + Base class for model's outputs that also contains a pooling of the last hidden states. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): + Last layer hidden-state of the first token of the sequence (classification token) after further processing + through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns + the classification token after processing through a linear layer and a tanh activation function. The linear + layer weights are trained from the next sentence prediction (classification) objective during pretraining. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` + input) to speed up sequential decoding. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + pooler_output: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + past_key_values: Optional[Cache] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class BaseModelOutputWithPastAndCrossAttentions(ModelOutput): + """ + Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding). + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` + input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class MoECausalLMOutputWithPast(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs as well as Mixture of Expert's router hidden + states terms, to train a MoE model. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + z_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided): + z_loss for the sparse modules. + aux_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided): + aux_loss for the sparse modules. + router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. + + Router logits of the encoder model, useful to compute the auxiliary loss and the z_loss for the sparse + modules. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + z_loss: Optional[torch.FloatTensor] = None + aux_loss: Optional[torch.FloatTensor] = None + router_logits: Optional[tuple[torch.FloatTensor]] = None + + +@dataclass +class MoEModelOutput(ModelOutput): + """ + Base class for model's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + router_probs (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. + + Raw router probabilities that are computed by MoE routers, these terms are used to compute the auxiliary + loss and the z_loss for Mixture of Experts models. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + router_probs: Optional[tuple[torch.FloatTensor]] = None + + +@dataclass +class MoeModelOutputWithPast(ModelOutput): + """ + Base class for model's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` + input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. + + Raw router logtis (post-softmax) that are computed by MoE routers, these terms are used to compute the auxiliary + loss for Mixture of Experts models. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + router_logits: Optional[tuple[torch.FloatTensor]] = None + + +@dataclass +class MoeCausalLMOutputWithPast(ModelOutput): + """ + Base class for causal language model (or autoregressive) with mixture of experts outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + + aux_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided): + aux_loss for the sparse modules. + + router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. + + Raw router logtis (post-softmax) that are computed by MoE routers, these terms are used to compute the auxiliary + loss for Mixture of Experts models. + + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + aux_loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + router_logits: Optional[tuple[torch.FloatTensor]] = None + + +@dataclass +class MoEModelOutputWithPastAndCrossAttentions(ModelOutput): + """ + Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding) as well as + Mixture of Expert's router hidden states terms, to train a MoE model. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` + input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + router_probs (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. + + Raw router probabilities that are computed by MoE routers, these terms are used to compute the auxiliary + loss and the z_loss for Mixture of Experts models. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + router_probs: Optional[tuple[torch.FloatTensor]] = None + + +@dataclass +class Seq2SeqModelOutput(ModelOutput): + """ + Base class for model encoder's outputs that also contains : pre-computed hidden states that can speed up sequential + decoding. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the optional initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the optional initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + past_key_values: Optional[EncoderDecoderCache] = None + decoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + encoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class Seq2SeqMoEModelOutput(ModelOutput): + """ + Base class for model encoder's outputs that also contains : pre-computed hidden states that can speed up sequential + decoding. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the optional initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + decoder_router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. + + Router logits of the decoder model, useful to compute the auxiliary loss for Mixture of Experts models. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the optional initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + encoder_router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. + + Router logits of the encoder model, useful to compute the auxiliary loss and the z_loss for the sparse + modules. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + past_key_values: Optional[EncoderDecoderCache] = None + decoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + decoder_router_logits: Optional[tuple[torch.FloatTensor]] = None + cross_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + encoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + encoder_router_logits: Optional[tuple[torch.FloatTensor]] = None + + +@dataclass +class CausalLMOutput(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class CausalLMOutputWithPast(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class CausalLMOutputWithCrossAttentions(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Cross attentions weights after the attention softmax, used to compute the weighted average in the + cross-attention heads. + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class SequenceClassifierOutputWithPast(ModelOutput): + """ + Base class for outputs of sentence classification models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class MaskedLMOutput(ModelOutput): + """ + Base class for masked language models outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Masked language modeling (MLM) loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class Seq2SeqLMOutput(ModelOutput): + """ + Base class for sequence-to-sequence language models outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[EncoderDecoderCache] = None + decoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + encoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class Seq2SeqMoEOutput(ModelOutput): + """ + Base class for sequence-to-sequence language models outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + decoder_router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. + + Router logits of the decoder model, useful to compute the auxiliary loss for Mixture of Experts models. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + encoder_router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. + + Router logits of the encoder model, useful to compute the auxiliary loss and z_loss for Mixture of Experts + models. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + encoder_z_loss: Optional[torch.FloatTensor] = None + decoder_z_loss: Optional[torch.FloatTensor] = None + encoder_aux_loss: Optional[torch.FloatTensor] = None + decoder_aux_loss: Optional[torch.FloatTensor] = None + past_key_values: Optional[EncoderDecoderCache] = None + decoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + decoder_router_logits: Optional[tuple[torch.FloatTensor]] = None + cross_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + encoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + encoder_router_logits: Optional[tuple[torch.FloatTensor]] = None + + +@dataclass +class NextSentencePredictorOutput(ModelOutput): + """ + Base class for outputs of models predicting if two sentences are consecutive or not. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `next_sentence_label` is provided): + Next sequence prediction (classification) loss. + logits (`torch.FloatTensor` of shape `(batch_size, 2)`): + Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation + before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class SequenceClassifierOutput(ModelOutput): + """ + Base class for outputs of sentence classification models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class Seq2SeqSequenceClassifierOutput(ModelOutput): + """ + Base class for outputs of sequence-to-sequence sentence classification models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `label` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[EncoderDecoderCache] = None + decoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + encoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class MultipleChoiceModelOutput(ModelOutput): + """ + Base class for outputs of multiple choice models. + + Args: + loss (`torch.FloatTensor` of shape *(1,)*, *optional*, returned when `labels` is provided): + Classification loss. + logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`): + *num_choices* is the second dimension of the input tensors. (see *input_ids* above). + + Classification scores (before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class TokenClassifierOutput(ModelOutput): + """ + Base class for outputs of token classification models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided) : + Classification loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`): + Classification scores (before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class QuestionAnsweringModelOutput(ModelOutput): + """ + Base class for outputs of question answering models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Total span extraction loss is the sum of a Cross-Entropy for the start and end positions. + start_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Span-start scores (before SoftMax). + end_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Span-end scores (before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + start_logits: Optional[torch.FloatTensor] = None + end_logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class Seq2SeqQuestionAnsweringModelOutput(ModelOutput): + """ + Base class for outputs of sequence-to-sequence question answering models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Total span extraction loss is the sum of a Cross-Entropy for the start and end positions. + start_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Span-start scores (before SoftMax). + end_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Span-end scores (before SoftMax). + past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + start_logits: Optional[torch.FloatTensor] = None + end_logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[EncoderDecoderCache] = None + decoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + encoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class SemanticSegmenterOutput(ModelOutput): + """ + Base class for outputs of semantic segmentation models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels, logits_height, logits_width)`): + Classification scores for each pixel. + + + + The logits returned do not necessarily have the same size as the `pixel_values` passed as inputs. This is + to avoid doing two interpolations and lose some quality when a user needs to resize the logits to the + original image size as post-processing. You should always check your logits shape and resize as needed. + + + + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, patch_size, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class ImageClassifierOutput(ModelOutput): + """ + Base class for outputs of image classification models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states + (also called feature maps) of the model at the output of each stage. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class ImageClassifierOutputWithNoAttention(ModelOutput): + """ + Base class for outputs of image classification models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each stage) of shape `(batch_size, num_channels, height, width)`. Hidden-states (also + called feature maps) of the model at the output of each stage. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class DepthEstimatorOutput(ModelOutput): + """ + Base class for outputs of depth estimation models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + predicted_depth (`torch.FloatTensor` of shape `(batch_size, height, width)`): + Predicted depth for each pixel. + + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, num_channels, height, width)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + predicted_depth: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class ImageSuperResolutionOutput(ModelOutput): + """ + Base class for outputs of image super resolution models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Reconstruction loss. + reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Reconstructed images, possibly upscaled. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states + (also called feature maps) of the model at the output of each stage. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + reconstruction: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class Wav2Vec2BaseModelOutput(ModelOutput): + """ + Base class for models that have been trained with the Wav2Vec2 loss objective. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + extract_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, conv_dim[-1])`): + Sequence of extracted feature vectors of the last convolutional layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + extract_features: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class XVectorOutput(ModelOutput): + """ + Output type of [`Wav2Vec2ForXVector`]. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.xvector_output_dim)`): + Classification hidden states before AMSoftmax. + embeddings (`torch.FloatTensor` of shape `(batch_size, config.xvector_output_dim)`): + Utterance embeddings used for vector similarity-based retrieval. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + embeddings: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class BackboneOutput(ModelOutput): + """ + Base class for outputs of backbones. + + Args: + feature_maps (`tuple(torch.FloatTensor)` of shape `(batch_size, num_channels, height, width)`): + Feature maps of the stages. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)` or `(batch_size, num_channels, height, width)`, + depending on the backbone. + + Hidden-states of the model at the output of each stage plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Only applicable if the backbone uses attention. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + feature_maps: Optional[tuple[torch.FloatTensor]] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class BaseModelOutputWithPoolingAndProjection(ModelOutput): + """ + Base class for model's outputs that also contains a pooling of the last hidden states. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): + Last layer hidden-state of the first token of the sequence (classification token) after further processing + through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns + the classification token after processing through a linear layer and a tanh activation function. The linear + layer weights are trained from the next sentence prediction (classification) objective during pretraining. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + projection_state (`tuple(torch.FloatTensor)`, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` of shape `(batch_size,config.project_dim)`. + + Text embeddings before the projection layer, used to mimic the last hidden state of the teacher encoder. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + pooler_output: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + projection_state: Optional[tuple[torch.FloatTensor]] = None + + +@dataclass +class Seq2SeqSpectrogramOutput(ModelOutput): + """ + Base class for sequence-to-sequence spectrogram outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Spectrogram generation loss. + spectrogram (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_bins)`): + The predicted spectrogram. + past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + spectrogram: Optional[torch.FloatTensor] = None + past_key_values: Optional[EncoderDecoderCache] = None + decoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + encoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class Seq2SeqTSModelOutput(ModelOutput): + """ + Base class for time series model's encoder outputs that also contains pre-computed hidden states that can speed up + sequential decoding. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the optional initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the optional initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + loc (`torch.FloatTensor` of shape `(batch_size,)` or `(batch_size, input_size)`, *optional*): + Shift values of each time series' context window which is used to give the model inputs of the same + magnitude and then used to shift back to the original magnitude. + scale (`torch.FloatTensor` of shape `(batch_size,)` or `(batch_size, input_size)`, *optional*): + Scaling values of each time series' context window which is used to give the model inputs of the same + magnitude and then used to rescale back to the original magnitude. + static_features (`torch.FloatTensor` of shape `(batch_size, feature size)`, *optional*): + Static features of each time series' in a batch which are copied to the covariates at inference time. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + past_key_values: Optional[EncoderDecoderCache] = None + decoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + encoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + loc: Optional[torch.FloatTensor] = None + scale: Optional[torch.FloatTensor] = None + static_features: Optional[torch.FloatTensor] = None + + +@dataclass +class Seq2SeqTSPredictionOutput(ModelOutput): + """ + Base class for time series model's decoder outputs that also contain the loss as well as the parameters of the + chosen distribution. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when a `future_values` is provided): + Distributional loss. + params (`torch.FloatTensor` of shape `(batch_size, num_samples, num_params)`): + Parameters of the chosen distribution. + past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + loc (`torch.FloatTensor` of shape `(batch_size,)` or `(batch_size, input_size)`, *optional*): + Shift values of each time series' context window which is used to give the model inputs of the same + magnitude and then used to shift back to the original magnitude. + scale (`torch.FloatTensor` of shape `(batch_size,)` or `(batch_size, input_size)`, *optional*): + Scaling values of each time series' context window which is used to give the model inputs of the same + magnitude and then used to rescale back to the original magnitude. + static_features (`torch.FloatTensor` of shape `(batch_size, feature size)`, *optional*): + Static features of each time series' in a batch which are copied to the covariates at inference time. + """ + + loss: Optional[torch.FloatTensor] = None + params: Optional[tuple[torch.FloatTensor, ...]] = None + past_key_values: Optional[EncoderDecoderCache] = None + decoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + encoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + loc: Optional[torch.FloatTensor] = None + scale: Optional[torch.FloatTensor] = None + static_features: Optional[torch.FloatTensor] = None + + +@dataclass +class SampleTSPredictionOutput(ModelOutput): + """ + Base class for time series model's predictions outputs that contains the sampled values from the chosen + distribution. + + Args: + sequences (`torch.FloatTensor` of shape `(batch_size, num_samples, prediction_length)` or `(batch_size, num_samples, prediction_length, input_size)`): + Sampled values from the chosen distribution. + """ + + sequences: Optional[torch.FloatTensor] = None + + +@dataclass +class MaskedImageModelingOutput(ModelOutput): + """ + Base class for outputs of masked image completion / in-painting models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided): + Reconstruction loss. + reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Reconstructed / completed images. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or + when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states + (also called feature maps) of the model at the output of each stage. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when + `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in + the self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + reconstruction: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + @property + def logits(self): + warnings.warn( + "logits attribute is deprecated and will be removed in version 5 of Transformers." + " Please use the reconstruction attribute to retrieve the final output instead.", + FutureWarning, + ) + return self.reconstruction diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_rope_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_rope_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c0070df6ee17ae08c4ddcd6294379bf207867e03 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_rope_utils.py @@ -0,0 +1,773 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from functools import wraps +from typing import Optional + +from .configuration_utils import PretrainedConfig +from .utils import is_torch_available, logging + + +logger = logging.get_logger(__name__) + + +if is_torch_available(): + import torch + + +def dynamic_rope_update(rope_forward): + """ + Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE + (i.e. a RoPE implementation that may recompute its frequencies in the forward pass). + + Args: + rope_forward (Callable): + The forward pass of the RoPE implementation. + + Returns: + The decorated forward pass. + """ + + def longrope_frequency_update(self, position_ids, device): + """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise.""" + seq_len = torch.max(position_ids) + 1 + if hasattr(self.config, "original_max_position_embeddings"): + original_max_position_embeddings = self.config.original_max_position_embeddings + else: + original_max_position_embeddings = self.config.max_position_embeddings + if seq_len > original_max_position_embeddings: + if not hasattr(self, "long_inv_freq"): + self.long_inv_freq, _ = self.rope_init_fn( + self.config, device, seq_len=original_max_position_embeddings + 1 + ) + self.register_buffer("inv_freq", self.long_inv_freq, persistent=False) + else: + # This .to() is needed if the model has been moved to a device after being initialized (because + # the buffer is automatically moved, but not the original copy) + self.original_inv_freq = self.original_inv_freq.to(device) + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + + def dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + # This .to() is needed if the model has been moved to a device after being initialized (because + # the buffer is automatically moved, but not the original copy) + self.original_inv_freq = self.original_inv_freq.to(device) + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @wraps(rope_forward) + def wrapper(self, x, position_ids): + if "dynamic" in self.rope_type: + dynamic_frequency_update(self, position_ids, device=x.device) + elif self.rope_type == "longrope": + longrope_frequency_update(self, position_ids, device=x.device) + return rope_forward(self, x, position_ids) + + return wrapper + + +def _compute_default_rope_parameters( + config: Optional[PretrainedConfig] = None, + device: Optional["torch.device"] = None, + seq_len: Optional[int] = None, +) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PretrainedConfig`]): + The model configuration. This function assumes that the config will provide at least the following + properties: + + * rope_theta (`float`): The base wavelength from which the inverse frequencies will be derived. + * hidden_size (`int`): The numerator when deriving a head_dim, if not provided directly. + * num_attention_heads (`int`): The denominator when deriving a head_dim, if not provided directly. + + Additionally, this function will make use of the following properties if they are found in the config: + + * head_dim (`int`, *optional*): The size of the key-value heads in the model. If None, this value will be + derived as hidden_size // num_attention_heads. + * partial_rotary_factor (`float`, *optional*): If less than 1.0, inverse frequencies will be returned for + the first fraction of the head_dim. Defaults to 1.0. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + base = config.rope_theta + partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0) + head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + dim = int(head_dim * partial_rotary_factor) + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)) + return inv_freq, attention_factor + + +def _compute_linear_scaling_rope_parameters( + config: Optional[PretrainedConfig] = None, + device: Optional["torch.device"] = None, + seq_len: Optional[int] = None, +) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies with linear scaling. Credits to the Reddit user /u/kaiokendev + Args: + config ([`~transformers.PretrainedConfig`]): + The model configuration. This function assumes that the config will provide at least the following + properties: + + * rope_theta (`float`): The base wavelength from which the inverse frequencies will be derived. + * hidden_size (`int`): The numerator when deriving a head_dim, if not provided directly. + * num_attention_heads (`int`): The denominator when deriving a head_dim, if not provided directly. + + Additionally, this function will make use of the following properties if they are found in the config: + + * head_dim (`int`, *optional*): The size of the key-value heads in the model. If None, this value will be + derived as hidden_size // num_attention_heads. + * partial_rotary_factor (`float`, *optional*): If less than 1.0, inverse frequencies will be returned for + the first fraction of the head_dim. Defaults to 1.0. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + factor = config.rope_scaling["factor"] + + # Gets the default RoPE parameters + inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len) + + # Then applies linear scaling to the frequencies. + # NOTE: originally, scaling was applied to the position_ids. However, we get `embs = inv_freq @ position_ids`, so + # applying scaling to the inverse frequencies is equivalent. + inv_freq /= factor + return inv_freq, attention_factor + + +def _compute_dynamic_ntk_parameters( + config: Optional[PretrainedConfig] = None, + device: Optional["torch.device"] = None, + seq_len: Optional[int] = None, +) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla + + Args: + config ([`~transformers.PretrainedConfig`]): + The model configuration. This function assumes that the config will provide at least the following + properties: + + * rope_theta (`float`): The base wavelength from which the inverse frequencies will be derived. + * hidden_size (`int`): The numerator when deriving a head_dim, if not provided directly. + * num_attention_heads (`int`): The denominator when deriving a head_dim, if not provided directly. + * max_position_embeddings (`int`): The default sequence length used to update the dynamic RoPE at + inference time + * rope_scaling (`dict[str, float]`): The standard RoPE scaling parameters, from which `factor` + will be accessed. The value of `factor` is used to determine the new base frequency, along with the + current sequence length (seq_len), the maximum positional embeddings (max_position_embeddings), and the + computed dimensionality (dim) of the rotary embeddings. If seq_len <= max_position_embeddings, this + factor has no effect. If seq_len <= max_position_embeddings, this factor effectively stretches the + context window using an exponent derived from `dim`. + + Additionally, this function will make use of the following properties if they are found in the config: + + * head_dim (`int`, *optional*): The size of the key-value heads in the model. If None, this value will be + derived as hidden_size // num_attention_heads. + * partial_rotary_factor (`float`, *optional*): If less than 1.0, inverse frequencies will be returned for + the first fraction of the head_dim. Defaults to 1.0. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length, used to update the dynamic RoPE at inference time. If `None` or shorter than + max_position_embeddings, this value will be overridden by max_position_embeddings. + + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling + base = config.rope_theta + partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0) + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) + max_position_embeddings = config.max_position_embeddings + factor = config.rope_scaling["factor"] + + attention_factor = 1.0 # Unused in this type of RoPE + + # seq_len: default to max_position_embeddings, e.g. at init time + if seq_len is None: + seq_len = max_position_embeddings + elif isinstance(seq_len, torch.Tensor): + seq_len = torch.maximum( + seq_len, + torch.tensor(max_position_embeddings, dtype=seq_len.dtype, device=seq_len.device), + ) + else: + seq_len = max(seq_len, max_position_embeddings) + + # Compute the inverse frequencies + base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2)) + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)) + return inv_freq, attention_factor + + +def _compute_yarn_parameters( + config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None +) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies with NTK scaling. Please refer to the + [original paper](https://huggingface.co/papers/2309.00071) + + Args: + config ([`~transformers.PretrainedConfig`]): + The model configuration. This function assumes that the config will provide at least the following + properties: + + * rope_theta (`float`): The base wavelength from which the inverse frequencies will be derived. + * hidden_size (`int`): The numerator when deriving a head_dim, if not provided directly. + * num_attention_heads (`int`): The denominator when deriving a head_dim, if not provided directly. + * max_position_embeddings (`int`): The maximum length of the positional embeddings. + * rope_scaling (`dict[str, float | int]`): The standard RoPE scaling parameters, from which the following + keys will be accessed: + * `attention_factor` (`float`, *optional*): The scaling factor to be applied to the computed cos/sin. + If None, the value is inferred from `factor`, `mscale`, and `mscale_all_dim` as avaialble. + * `beta_fast` (`float`, *optional*, defaults to 32): Parameter to set the boundary for extrapolation + (only) in the linear ramp function. + * `beta_slow` (`float`, *optional*, defaults to 1): Parameter to set the boundary for interpolation + (only) in the linear ramp function. + * `factor` (`float`, *optional*): The scaling factor applied when interpolating the position IDs to + extend the possible context length. Additionally, if `attention_factor` is None, the log of this + value is used to compute a value for `attention_factor`, possibly in conjunciton with `mscale` and + `mscale_all_dim`, if provided. + * `mscale` (`float`, *optional*): If `attention_factor` is None and both `mscale` and + `mscale_all_dim` are provided, `mscale` acts scalar augmenting `log(factor)` when computing the + numerator for the inferred value of `attention_factor`. If not provided, `attention_factor` will be + calculated based on `factor` only. + * `mscale_all_dim` (`float`, *optional*): If `attention_factor` is None and both `mscale` and + `mscale_all_dim` are provided, `mscale_all_dim` acts scalar augmenting `log(factor)` when computing + the denominator for the inferred value of `attention_factor`. If not provided, `attention_factor` + will be calculated based on `factor` only. + * `original_max_position_embeddings` (`int`, *optional*): The original max position embeddings used + during pretraining. If not provided, the function falls back to `max_position_embeddings`. + * `truncate` (`bool`, *optional*): Whether to truncate the correction range. + + Additionally, this function will make use of the following properties if they are found in the config: + + * head_dim (`int`, *optional*): The size of the key-value heads in the model. If None, this value will be + derived as hidden_size // num_attention_heads. + * partial_rotary_factor (`float`, *optional*, defaults to 1.0): If less than 1.0, inverse frequencies + will be returned for the first fraction of the head_dim. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin. + """ + + base = config.rope_theta + partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0) + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) + factor = config.rope_scaling["factor"] + attention_factor = config.rope_scaling.get("attention_factor") + mscale = config.rope_scaling.get("mscale") + mscale_all_dim = config.rope_scaling.get("mscale_all_dim") + original_max_position_embeddings = ( + config.rope_scaling.get("original_max_position_embeddings") or config.max_position_embeddings + ) + + def get_mscale(scale, mscale=1): + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + # Sets the attention factor as suggested in the paper + if attention_factor is None: + if mscale and mscale_all_dim: + attention_factor = float(get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim)) + else: + attention_factor = get_mscale(factor) + + # Optional config options + # beta_fast/beta_slow: as suggested in the paper, default to 32 and 1 respectively + beta_fast = config.rope_scaling.get("beta_fast") or 32 + beta_slow = config.rope_scaling.get("beta_slow") or 1 + + # Compute the inverse frequencies + def find_correction_dim(num_rotations, dim, base, max_position_embeddings): + """Inverse dimension formula to find the dimension based on the number of rotations""" + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) + + def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings, truncate): + """Find dimension range bounds based on rotations""" + low = find_correction_dim(low_rot, dim, base, max_position_embeddings) + high = find_correction_dim(high_rot, dim, base, max_position_embeddings) + if truncate: + low = math.floor(low) + high = math.ceil(high) + return max(low, 0), min(high, dim - 1) + + def linear_ramp_factor(min, max, dim): + if min == max: + max += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs + # to expand the possible context length. In other words, interpolation = apply scaling factor. + pos_freqs = base ** (torch.arange(0, dim, 2).to(device=device, dtype=torch.float) / dim) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (factor * pos_freqs) + + truncate = config.rope_scaling.get("truncate", True) + low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings, truncate) + + # Get n-dimensional rotational scaling corrected for extrapolation + inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).to(device=device, dtype=torch.float) + inv_freq = ( + inv_freq_interpolation * (1 - inv_freq_extrapolation_factor) + + inv_freq_extrapolation * inv_freq_extrapolation_factor + ) + return inv_freq, attention_factor + + +def _compute_longrope_parameters( + config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None +) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies with LongRoPE scaling. Please refer to the + [original implementation](https://github.com/microsoft/LongRoPE) + + Args: + config ([`~transformers.PretrainedConfig`]): + The model configuration. This function assumes that the config will provide at least the following + properties: + + * rope_theta (`float`): The base wavelength from which the inverse frequencies will be derived. + * hidden_size (`int`): The numerator when deriving a head_dim, if not provided directly. + * num_attention_heads (`int`): The denominator when deriving a head_dim, if not provided directly. + * max_position_embeddings (`int`): The maximum length of the positional embeddings. + * original_max_position_embeddings (`int`, *optional*): The original max position embeddings used during + pretraining. If not provided, defaults to `max_position_embeddings`. + * rope_scaling (`dict[str, float]`): The standard RoPE scaling parameters, from which the following keys + will be accessed: + * `attention_factor` (`float`, *optional*): The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, inferred from + the value of `factor`. + * `factor` (`float`, *optional*): The scaling factor to apply to the RoPE embeddings. If both + `max_position_embeddings` and `original_max_position_embeddings` are provided, this value will be + overridden s the ratio between those values. + * `long_factor` (`float`, *optional*): The scale factor applied when computing the inverse + frequencies if `seq_len` is provided and greater than `original_max_position_embeddings`. + * `short_factor` (`float`, *optional*): The scale factor applied when computing the inverse + frequencies if `seq_len` is None or less-than-or-equal-to `original_max_position_embeddings`. + + Additionally, this function will make use of the following properties if they are found in the config: + + * head_dim (`int`, *optional*): The size of the key-value heads in the model. If None, this value will be + derived as hidden_size // num_attention_heads. + * partial_rotary_factor (`float`, *optional*, defaults to 1.0): If less than 1.0, inverse frequencies + will be returned for the first fraction of the head_dim. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. + + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin. + """ + # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling + base = config.rope_theta + partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0) + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) + long_factor = config.rope_scaling["long_factor"] + short_factor = config.rope_scaling["short_factor"] + factor = config.rope_scaling.get("factor") + attention_factor = config.rope_scaling.get("attention_factor") + + # NOTE: Phi3 (and potentially other models) modify `max_position_embeddings` and have a + # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two + # values to compute the default attention scaling factor, instead of using `factor`. + if original_max_position_embeddings := getattr(config, "original_max_position_embeddings", None): + factor = config.max_position_embeddings / original_max_position_embeddings + else: + original_max_position_embeddings = config.max_position_embeddings + + # Sets the attention factor as suggested in the paper + if attention_factor is None: + if factor <= 1.0: + attention_factor = 1.0 + else: + attention_factor = math.sqrt(1 + math.log(factor) / math.log(original_max_position_embeddings)) + + # Compute the inverse frequencies -- scaled based on the target sequence length + if seq_len and seq_len > original_max_position_embeddings: + ext_factors = torch.tensor(long_factor, dtype=torch.float32, device=device) + else: + ext_factors = torch.tensor(short_factor, dtype=torch.float32, device=device) + inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64, device=device).float() / dim + inv_freq = 1.0 / (ext_factors * base**inv_freq_shape) + + return inv_freq, attention_factor + + +def _compute_llama3_parameters( + config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None +) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies for llama 3.1. + + Args: + config ([`~transformers.PretrainedConfig`]): + The model configuration. This function assumes that the config will provide at least the following + properties: + + * rope_theta (`float`): The base wavelength from which the inverse frequencies will be derived. + * hidden_size (`int`): The numerator when deriving a head_dim, if not provided directly. + * num_attention_heads (`int`): The denominator when deriving a head_dim, if not provided directly. + * rope_scaling (`dict[str, float | int]`): The standard RoPE scaling parameters, from which the following + keys will be accessed: + * `factor` (`float`, *optional*): The scaling factor applied to the inverse frequencies when 1) the + wavelength is greater than `low_freq_wavelen` prior to smoothing, and 2) to all inverse frequencies + during smoothing. + * `high_freq_factor` (`float`): The scale factor used to compute `high_freq_wavelen` and + the value for the denominator of the smoothing factor prior to the `low_freq_factor` shift. + * `low_freq_factor` (`float`): The scale factor used to compute `low_freq_wavelen` and + the shift applied to the numerator and denominator of the smoothing factor. + frequencies if `seq_len` is None or less-than-or-equal-to `original_max_position_embeddings`. + * `original_max_position_embeddings` (`int`): The original max position embeddings used + during pretraining. If not provided, the function falls back to `max_position_embeddings`. + + Additionally, this function will make use of the following properties if they are found in the config: + + * head_dim (`int`, *optional*): The size of the key-value heads in the model. If None, this value will be + derived as hidden_size // num_attention_heads. + * partial_rotary_factor (`float`, *optional*): If less than 1.0, inverse frequencies will be returned for + the first fraction of the head_dim. Defaults to 1.0. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin. + """ + # Gets the default RoPE parameters + inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len) + + factor = config.rope_scaling["factor"] # `8` in the original implementation + low_freq_factor = config.rope_scaling["low_freq_factor"] # `1` in the original implementation + high_freq_factor = config.rope_scaling["high_freq_factor"] # `4` in the original implementation + old_context_len = config.rope_scaling["original_max_position_embeddings"] # `8192` in the original implementation + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + + wavelen = 2 * math.pi / inv_freq + # wavelen < high_freq_wavelen: do nothing + # wavelen > low_freq_wavelen: divide by factor + inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq) + # otherwise: interpolate between the two, using a smooth factor + smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) + smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama + is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen) + inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama) + + return inv_freq_llama, attention_factor + + +# This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters +# from the model config. You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE +# parameterizations, as long as the callable has the same signature. +ROPE_INIT_FUNCTIONS = { + "default": _compute_default_rope_parameters, + "linear": _compute_linear_scaling_rope_parameters, + "dynamic": _compute_dynamic_ntk_parameters, + "yarn": _compute_yarn_parameters, + "longrope": _compute_longrope_parameters, + "llama3": _compute_llama3_parameters, +} + + +def _check_received_keys( + rope_type: str, + received_keys: set, + required_keys: set, + optional_keys: Optional[set] = None, + ignore_keys: Optional[set] = None, +): + """Compare the received keys in `config.rope_scaling` against the expected and optional keys""" + # BC: "rope_type" was originally "type" -- let's check for "rope_type" when "type" is present + if "type" in received_keys: + received_keys -= {"type"} + required_keys.add("rope_type") + + # Some models need to store model-specific keys, and we don't want to throw warning at them + if ignore_keys is not None: + received_keys -= ignore_keys + + missing_keys = required_keys - received_keys + if missing_keys: + raise KeyError(f"Missing required keys in `rope_scaling` for 'rope_type'='{rope_type}': {missing_keys}") + + if optional_keys is not None: + unused_keys = received_keys - required_keys - optional_keys + else: + unused_keys = received_keys - required_keys + if unused_keys: + logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}") + + +def _validate_default_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None): + rope_scaling = config.rope_scaling + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" + required_keys = {"rope_type"} + received_keys = set(rope_scaling.keys()) + _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys) + + +def _validate_linear_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None): + rope_scaling = config.rope_scaling + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" + required_keys = {"rope_type", "factor"} + received_keys = set(rope_scaling.keys()) + _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys) + + factor = rope_scaling["factor"] + if factor is None or not isinstance(factor, float) or factor < 1.0: + logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") + + +def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None): + rope_scaling = config.rope_scaling + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" + required_keys = {"rope_type", "factor"} + # TODO (joao): update logic for the inclusion of `original_max_position_embeddings` + optional_keys = {"original_max_position_embeddings"} + received_keys = set(rope_scaling.keys()) + _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys) + + factor = rope_scaling["factor"] + if factor is None or not isinstance(factor, float) or factor < 1.0: + logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") + + +def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None): + rope_scaling = config.rope_scaling + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" + required_keys = {"rope_type", "factor"} + optional_keys = { + "attention_factor", + "beta_fast", + "beta_slow", + "original_max_position_embeddings", + "mscale", + "mscale_all_dim", + "truncate", + } + received_keys = set(rope_scaling.keys()) + _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys) + + factor = rope_scaling["factor"] + if factor is None or not isinstance(factor, float) or factor < 1.0: + logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") + + attention_factor = rope_scaling.get("attention_factor") + if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0): + logger.warning( + f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}" + ) + beta_fast = rope_scaling.get("beta_fast") + if beta_fast is not None and not isinstance(beta_fast, float): + logger.warning(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}") + beta_slow = rope_scaling.get("beta_slow") + if beta_slow is not None and not isinstance(beta_slow, float): + logger.warning(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}") + + if (beta_fast or 32) < (beta_slow or 1): + logger.warning( + f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} " + f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)" + ) + + # Models should set `config.rope_scaling["original_max_position_embeddings"]` to their original (pre-yarn) context + # length, with `config.max_position_embeddings` corresponding to their post-yarn context length. + # However, for BC purposes, we allow the former to be unset. + original_max_position_embeddings = config.rope_scaling.get("original_max_position_embeddings") + if original_max_position_embeddings is not None: + # Double-check: `factor` should be the ratio between the pre-yarn and post-yarn context lengths. + implicit_factor = config.max_position_embeddings / original_max_position_embeddings + if implicit_factor != factor: + logger.warning_once( + f"The explicitly set RoPE scaling factor (config.rope_scaling['factor'] = {factor}) does not match " + "the ratio implicitly set by other parameters (implicit factor = " + "post-yarn context length / pre-yarn context length = " + "config.max_position_embeddings / config.rope_scaling['original_max_position_embeddings'] = " + f"{implicit_factor}). Using the explicit factor ({factor}) in YaRN. This may cause unexpected " + "behaviour in model usage, please correct the 'max_position_embeddings' fields in the model config." + ) + # No `config.rope_scaling["original_max_position_embeddings"]`. Is `config.max_position_embeddings` the + # pre-yarn or the post-yarn context length? + # BC: we assume it is the pre-yarn context length. + else: + logger.warning_once( + "config.rope_scaling['original_max_position_embeddings'], the pre-yarn context length, is unset. We will " + "**assume** config.max_position_embeddings holds the pre-yarn context length. Some use cases may expect " + "config.max_position_embeddings to hold the post-yarn context length (pre-yarn context length * " + "factor) -- we recommend updating both fields for optimal downstream model usage." + ) + + +def _validate_longrope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None): + rope_scaling = config.rope_scaling + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" + required_keys = {"rope_type", "short_factor", "long_factor"} + # TODO (joao): update logic for the inclusion of `original_max_position_embeddings` + optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"} + received_keys = set(rope_scaling.keys()) + _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys) + + partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0) + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) + + short_factor = rope_scaling.get("short_factor") + if not isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor): + logger.warning(f"`rope_scaling`'s short_factor field must be a list of numbers, got {short_factor}") + if len(short_factor) != dim // 2: + logger.warning(f"`rope_scaling`'s short_factor field must have length {dim // 2}, got {len(short_factor)}") + + long_factor = rope_scaling.get("long_factor") + if not isinstance(long_factor, list) and all(isinstance(x, (int, float)) for x in long_factor): + logger.warning(f"`rope_scaling`'s long_factor field must be a list of numbers, got {long_factor}") + if len(long_factor) != dim // 2: + logger.warning(f"`rope_scaling`'s long_factor field must have length {dim // 2}, got {len(long_factor)}") + + # Handle Phi3 divergence: prefer the use of `attention_factor` and/or `factor` over + # `original_max_position_embeddings` to compute internal variables. The latter lives outside `rope_scaling` and is + # unique to longrope (= undesirable) + if hasattr(config, "original_max_position_embeddings"): + logger.warning_once( + "This model has set a `original_max_position_embeddings` field, to be used together with " + "`max_position_embeddings` to determine a scaling factor. Please set the `factor` field of `rope_scaling`" + "with this ratio instead -- we recommend the use of this field over `original_max_position_embeddings`, " + "as it is compatible with most model architectures." + ) + else: + factor = rope_scaling.get("factor") + if factor is None: + logger.warning("Missing required keys in `rope_scaling`: 'factor'") + elif not isinstance(factor, float) or factor < 1.0: + logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") + + attention_factor = rope_scaling.get("attention_factor") + if attention_factor is not None: + if not isinstance(attention_factor, float) or attention_factor < 0.0: + logger.warning( + f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}" + ) + + +def _validate_llama3_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None): + rope_scaling = config.rope_scaling + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" + required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"} + received_keys = set(rope_scaling.keys()) + _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys) + + factor = rope_scaling["factor"] + if factor is None or not isinstance(factor, float) or factor < 1.0: + logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") + + low_freq_factor = rope_scaling["low_freq_factor"] + high_freq_factor = rope_scaling["high_freq_factor"] + if low_freq_factor is None or not isinstance(low_freq_factor, float): + logger.warning(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}") + if high_freq_factor is None or not isinstance(high_freq_factor, float): + logger.warning(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}") + if high_freq_factor <= low_freq_factor: + logger.warning( + "`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor=" + f"{high_freq_factor} and low_freq_factor={low_freq_factor}" + ) + + original_max_position_embeddings = rope_scaling["original_max_position_embeddings"] + if original_max_position_embeddings is None or not isinstance(original_max_position_embeddings, int): + logger.warning( + "`rope_scaling`'s original_max_position_embeddings field must be an integer, got " + f"{original_max_position_embeddings}" + ) + if original_max_position_embeddings >= config.max_position_embeddings: + logger.warning( + "`rope_scaling`'s original_max_position_embeddings field must be less than max_position_embeddings, got " + f"{original_max_position_embeddings} and max_position_embeddings={config.max_position_embeddings}" + ) + + +# Like `ROPE_INIT_FUNCTIONS`, this validation function mapping can be dynamically updated for custom RoPE types. +ROPE_VALIDATION_FUNCTIONS = { + "default": _validate_default_rope_parameters, + "linear": _validate_linear_scaling_rope_parameters, + "dynamic": _validate_dynamic_scaling_rope_parameters, + "yarn": _validate_yarn_parameters, + "longrope": _validate_longrope_parameters, + "llama3": _validate_llama3_parameters, +} + + +def rope_config_validation(config: PretrainedConfig, ignore_keys: Optional[set] = None): + """ + Validate the RoPE config arguments, given a `PretrainedConfig` object + """ + rope_scaling = getattr(config, "rope_scaling", None) # not a default parameter in `PretrainedConfig` + if rope_scaling is None: + return + + # BC: "rope_type" was originally "type" + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default")) + validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type) + if validation_fn is not None: + validation_fn(config, ignore_keys=ignore_keys) + else: + logger.warning( + f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'" + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_tf_outputs.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_tf_outputs.py new file mode 100644 index 0000000000000000000000000000000000000000..c7491b67f9aebb93b95632a5a7db2fd85cf5b4c7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_tf_outputs.py @@ -0,0 +1,990 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import warnings +from dataclasses import dataclass + +import tensorflow as tf + +from .utils import ModelOutput + + +@dataclass +class TFBaseModelOutput(ModelOutput): + """ + Base class for model's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(tf.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + last_hidden_state: tf.Tensor | None = None + hidden_states: tuple[tf.Tensor] | None = None + attentions: tuple[tf.Tensor] | None = None + + +@dataclass +class TFBaseModelOutputWithNoAttention(ModelOutput): + """ + Base class for model's outputs, with potential hidden states. + + Args: + last_hidden_state (`tf.Tensor` shape `(batch_size, num_channels, height, width)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for + the output of each layer) of shape `(batch_size, num_channels, height, width)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + """ + + last_hidden_state: tf.Tensor | None = None + hidden_states: tuple[tf.Tensor, ...] | None = None + + +@dataclass +class TFBaseModelOutputWithPooling(ModelOutput): + """ + Base class for model's outputs that also contains a pooling of the last hidden states. + + Args: + last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`tf.Tensor` of shape `(batch_size, hidden_size)`): + Last layer hidden-state of the first token of the sequence (classification token) further processed by a + Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence + prediction (classification) objective during pretraining. + + This output is usually *not* a good summary of the semantic content of the input, you're often better with + averaging or pooling the sequence of hidden-states for the whole input sequence. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + last_hidden_state: tf.Tensor | None = None + pooler_output: tf.Tensor | None = None + hidden_states: tuple[tf.Tensor] | None = None + attentions: tuple[tf.Tensor] | None = None + + +@dataclass +class TFBaseModelOutputWithPoolingAndNoAttention(ModelOutput): + """ + Base class for model's outputs that also contains a pooling of the last hidden states. + + Args: + last_hidden_state (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`tf.Tensor` of shape `(batch_size, hidden_size)`): + Last layer hidden-state after a pooling operation on the spatial dimensions. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for + the output of each layer) of shape `(batch_size, num_channels, height, width)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + """ + + last_hidden_state: tf.Tensor | None = None + pooler_output: tf.Tensor | None = None + hidden_states: tuple[tf.Tensor, ...] | None = None + + +@dataclass +class TFBaseModelOutputWithPoolingAndCrossAttentions(ModelOutput): + """ + Base class for model's outputs that also contains a pooling of the last hidden states. + + Args: + last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`tf.Tensor` of shape `(batch_size, hidden_size)`): + Last layer hidden-state of the first token of the sequence (classification token) further processed by a + Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence + prediction (classification) objective during pretraining. + + This output is usually *not* a good summary of the semantic content of the input, you're often better with + averaging or pooling the sequence of hidden-states for the whole input sequence. + past_key_values (`list[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads, + sequence_length, embed_size_per_head)`). + + Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + """ + + last_hidden_state: tf.Tensor | None = None + pooler_output: tf.Tensor | None = None + past_key_values: list[tf.Tensor] | None = None + hidden_states: tuple[tf.Tensor] | None = None + attentions: tuple[tf.Tensor] | None = None + cross_attentions: tuple[tf.Tensor] | None = None + + +@dataclass +class TFBaseModelOutputWithPast(ModelOutput): + """ + Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding). + + Args: + last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + past_key_values (`list[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads, + sequence_length, embed_size_per_head)`). + + Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + last_hidden_state: tf.Tensor | None = None + past_key_values: list[tf.Tensor] | None = None + hidden_states: tuple[tf.Tensor] | None = None + attentions: tuple[tf.Tensor] | None = None + + +@dataclass +class TFBaseModelOutputWithCrossAttentions(ModelOutput): + """ + Base class for model's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(tf.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + """ + + last_hidden_state: tf.Tensor | None = None + hidden_states: tuple[tf.Tensor] | None = None + attentions: tuple[tf.Tensor] | None = None + cross_attentions: tuple[tf.Tensor] | None = None + + +@dataclass +class TFBaseModelOutputWithPastAndCrossAttentions(ModelOutput): + """ + Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding). + + Args: + last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + past_key_values (`list[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads, + sequence_length, embed_size_per_head)`). + + Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(tf.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + """ + + last_hidden_state: tf.Tensor | None = None + past_key_values: list[tf.Tensor] | None = None + hidden_states: tuple[tf.Tensor] | None = None + attentions: tuple[tf.Tensor] | None = None + cross_attentions: tuple[tf.Tensor] | None = None + + +@dataclass +class TFSeq2SeqModelOutput(ModelOutput): + """ + Base class for model encoder's outputs that also contains : pre-computed hidden states that can speed up sequential + decoding. + + Args: + last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + past_key_values (`list[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads, + sequence_length, embed_size_per_head)`). + + Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be + used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. + decoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. + encoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + last_hidden_state: tf.Tensor | None = None + past_key_values: list[tf.Tensor] | None = None + decoder_hidden_states: tuple[tf.Tensor] | None = None + decoder_attentions: tuple[tf.Tensor] | None = None + cross_attentions: tuple[tf.Tensor] | None = None + encoder_last_hidden_state: tf.Tensor | None = None + encoder_hidden_states: tuple[tf.Tensor] | None = None + encoder_attentions: tuple[tf.Tensor] | None = None + + +@dataclass +class TFCausalLMOutput(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs. + + Args: + loss (`tf.Tensor` of shape `(n,)`, *optional*, where n is the number of non-masked labels, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: tf.Tensor | None = None + logits: tf.Tensor | None = None + hidden_states: tuple[tf.Tensor] | None = None + attentions: tuple[tf.Tensor] | None = None + + +@dataclass +class TFCausalLMOutputWithPast(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs. + + Args: + loss (`tf.Tensor` of shape `(n,)`, *optional*, where n is the number of non-masked labels, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`list[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads, + sequence_length, embed_size_per_head)`). + + Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: tf.Tensor | None = None + logits: tf.Tensor | None = None + past_key_values: list[tf.Tensor] | None = None + hidden_states: tuple[tf.Tensor] | None = None + attentions: tuple[tf.Tensor] | None = None + + +@dataclass +class TFCausalLMOutputWithCrossAttentions(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs. + + Args: + loss (`tf.Tensor` of shape `(n,)`, *optional*, where n is the number of non-masked labels, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + past_key_values (`list[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads, + sequence_length, embed_size_per_head)`). + + Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + """ + + loss: tf.Tensor | None = None + logits: tf.Tensor | None = None + past_key_values: list[tf.Tensor] | None = None + hidden_states: tuple[tf.Tensor] | None = None + attentions: tuple[tf.Tensor] | None = None + cross_attentions: tuple[tf.Tensor] | None = None + + +@dataclass +class TFMaskedLMOutput(ModelOutput): + """ + Base class for masked language models outputs. + + Args: + loss (`tf.Tensor` of shape `(n,)`, *optional*, where n is the number of non-masked labels, returned when `labels` is provided): + Masked language modeling (MLM) loss. + logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: tf.Tensor | None = None + logits: tf.Tensor | None = None + hidden_states: tuple[tf.Tensor] | None = None + attentions: tuple[tf.Tensor] | None = None + + +@dataclass +class TFSeq2SeqLMOutput(ModelOutput): + """ + Base class for sequence-to-sequence language models outputs. + + Args: + loss (`tf.Tensor` of shape `(n,)`, *optional*, where n is the number of non-masked labels, returned when `labels` is provided): + Language modeling loss. + logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`list[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads, + sequence_length, embed_size_per_head)`). + + Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be + used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. + decoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. + encoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + loss: tf.Tensor | None = None + logits: tf.Tensor | None = None + past_key_values: list[tf.Tensor] | None = None + decoder_hidden_states: tuple[tf.Tensor] | None = None + decoder_attentions: tuple[tf.Tensor] | None = None + cross_attentions: tuple[tf.Tensor] | None = None + encoder_last_hidden_state: tf.Tensor | None = None + encoder_hidden_states: tuple[tf.Tensor] | None = None + encoder_attentions: tuple[tf.Tensor] | None = None + + +@dataclass +class TFNextSentencePredictorOutput(ModelOutput): + """ + Base class for outputs of models predicting if two sentences are consecutive or not. + + Args: + loss (`tf.Tensor` of shape `(n,)`, *optional*, where n is the number of non-masked labels, returned when `next_sentence_label` is provided): + Next sentence prediction loss. + logits (`tf.Tensor` of shape `(batch_size, 2)`): + Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation + before SoftMax). + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: tf.Tensor | None = None + logits: tf.Tensor | None = None + hidden_states: tuple[tf.Tensor] | None = None + attentions: tuple[tf.Tensor] | None = None + + +@dataclass +class TFSequenceClassifierOutput(ModelOutput): + """ + Base class for outputs of sentence classification models. + + Args: + loss (`tf.Tensor` of shape `(batch_size, )`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: tf.Tensor | None = None + logits: tf.Tensor | None = None + hidden_states: tuple[tf.Tensor] | None = None + attentions: tuple[tf.Tensor] | None = None + + +@dataclass +class TFSeq2SeqSequenceClassifierOutput(ModelOutput): + """ + Base class for outputs of sequence-to-sequence sentence classification models. + + Args: + loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `label` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + past_key_values (`list[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads, + sequence_length, embed_size_per_head)`). + + Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be + used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. + decoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)` + encoder_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. + encoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + loss: tf.Tensor | None = None + logits: tf.Tensor | None = None + past_key_values: list[tf.Tensor] | None = None + decoder_hidden_states: tuple[tf.Tensor] | None = None + decoder_attentions: tuple[tf.Tensor] | None = None + cross_attentions: tuple[tf.Tensor] | None = None + encoder_last_hidden_state: tf.Tensor | None = None + encoder_hidden_states: tuple[tf.Tensor] | None = None + encoder_attentions: tuple[tf.Tensor] | None = None + + +@dataclass +class TFSemanticSegmenterOutput(ModelOutput): + """ + Base class for outputs of semantic segmentation models. + + Args: + loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`tf.Tensor` of shape `(batch_size, config.num_labels, logits_height, logits_width)`): + Classification scores for each pixel. + + + + The logits returned do not necessarily have the same size as the `pixel_values` passed as inputs. This is + to avoid doing two interpolations and lose some quality when a user needs to resize the logits to the + original image size as post-processing. You should always check your logits shape and resize as needed. + + + + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for + the output of each layer) of shape `(batch_size, patch_size, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: tf.Tensor | None = None + logits: tf.Tensor | None = None + hidden_states: tuple[tf.Tensor] | None = None + attentions: tuple[tf.Tensor] | None = None + + +@dataclass +class TFSemanticSegmenterOutputWithNoAttention(ModelOutput): + """ + Base class for outputs of semantic segmentation models that do not output attention scores. + + Args: + loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`tf.Tensor` of shape `(batch_size, config.num_labels, logits_height, logits_width)`): + Classification scores for each pixel. + + + + The logits returned do not necessarily have the same size as the `pixel_values` passed as inputs. This is + to avoid doing two interpolations and lose some quality when a user needs to resize the logits to the + original image size as post-processing. You should always check your logits shape and resize as needed. + + + + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for + the output of each layer) of shape `(batch_size, patch_size, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + """ + + loss: tf.Tensor | None = None + logits: tf.Tensor | None = None + hidden_states: tuple[tf.Tensor] | None = None + + +@dataclass +class TFImageClassifierOutput(ModelOutput): + """ + Base class for outputs of image classification models. + + Args: + loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for + the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states (also called + feature maps) of the model at the output of each stage. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: tf.Tensor | None = None + logits: tf.Tensor | None = None + hidden_states: tuple[tf.Tensor] | None = None + attentions: tuple[tf.Tensor] | None = None + + +@dataclass +class TFMultipleChoiceModelOutput(ModelOutput): + """ + Base class for outputs of multiple choice models. + + Args: + loss (`tf.Tensor` of shape *(batch_size, )*, *optional*, returned when `labels` is provided): + Classification loss. + logits (`tf.Tensor` of shape `(batch_size, num_choices)`): + *num_choices* is the second dimension of the input tensors. (see *input_ids* above). + + Classification scores (before SoftMax). + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: tf.Tensor | None = None + logits: tf.Tensor | None = None + hidden_states: tuple[tf.Tensor] | None = None + attentions: tuple[tf.Tensor] | None = None + + +@dataclass +class TFTokenClassifierOutput(ModelOutput): + """ + Base class for outputs of token classification models. + + Args: + loss (`tf.Tensor` of shape `(n,)`, *optional*, where n is the number of unmasked labels, returned when `labels` is provided) : + Classification loss. + logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.num_labels)`): + Classification scores (before SoftMax). + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: tf.Tensor | None = None + logits: tf.Tensor | None = None + hidden_states: tuple[tf.Tensor] | None = None + attentions: tuple[tf.Tensor] | None = None + + +@dataclass +class TFQuestionAnsweringModelOutput(ModelOutput): + """ + Base class for outputs of question answering models. + + Args: + loss (`tf.Tensor` of shape `(batch_size, )`, *optional*, returned when `start_positions` and `end_positions` are provided): + Total span extraction loss is the sum of a Cross-Entropy for the start and end positions. + start_logits (`tf.Tensor` of shape `(batch_size, sequence_length)`): + Span-start scores (before SoftMax). + end_logits (`tf.Tensor` of shape `(batch_size, sequence_length)`): + Span-end scores (before SoftMax). + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: tf.Tensor | None = None + start_logits: tf.Tensor | None = None + end_logits: tf.Tensor | None = None + hidden_states: tuple[tf.Tensor] | None = None + attentions: tuple[tf.Tensor] | None = None + + +@dataclass +class TFSeq2SeqQuestionAnsweringModelOutput(ModelOutput): + """ + Base class for outputs of sequence-to-sequence question answering models. + + Args: + loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Total span extraction loss is the sum of a Cross-Entropy for the start and end positions. + start_logits (`tf.Tensor` of shape `(batch_size, sequence_length)`): + Span-start scores (before SoftMax). + end_logits (`tf.Tensor` of shape `(batch_size, sequence_length)`): + Span-end scores (before SoftMax). + past_key_values (`list[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads, + sequence_length, embed_size_per_head)`). + + Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be + used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. + decoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + encoder_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. + encoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + loss: tf.Tensor | None = None + start_logits: tf.Tensor | None = None + end_logits: tf.Tensor | None = None + past_key_values: list[tf.Tensor] | None = None + decoder_hidden_states: tuple[tf.Tensor] | None = None + decoder_attentions: tuple[tf.Tensor] | None = None + encoder_last_hidden_state: tf.Tensor | None = None + encoder_hidden_states: tuple[tf.Tensor] | None = None + encoder_attentions: tuple[tf.Tensor] | None = None + + +@dataclass +class TFSequenceClassifierOutputWithPast(ModelOutput): + """ + Base class for outputs of sentence classification models. + + Args: + loss (`tf.Tensor` of shape `(batch_size, )`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + past_key_values (`list[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads, + sequence_length, embed_size_per_head)`). + + Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: tf.Tensor | None = None + logits: tf.Tensor | None = None + past_key_values: list[tf.Tensor] | None = None + hidden_states: tuple[tf.Tensor] | None = None + attentions: tuple[tf.Tensor] | None = None + + +@dataclass +class TFImageClassifierOutputWithNoAttention(ModelOutput): + """ + Base class for outputs of image classification models. + + Args: + loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for + the output of each stage) of shape `(batch_size, num_channels, height, width)`. Hidden-states (also called + feature maps) of the model at the output of each stage. + """ + + loss: tf.Tensor | None = None + logits: tf.Tensor | None = None + hidden_states: tuple[tf.Tensor, ...] | None = None + + +@dataclass +class TFMaskedImageModelingOutput(ModelOutput): + """ + Base class for outputs of masked image completion / in-painting models. + + Args: + loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided): + Reconstruction loss. + reconstruction (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): + Reconstructed / completed images. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when + `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for + the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states (also called + feature maps) of the model at the output of each stage. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when + `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, sequence_length)`. + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: tf.Tensor | None = None + reconstruction: tf.Tensor | None = None + hidden_states: tuple[tf.Tensor] | None = None + attentions: tuple[tf.Tensor] | None = None + + @property + def logits(self): + warnings.warn( + "logits attribute is deprecated and will be removed in version 5 of Transformers." + " Please use the reconstruction attribute to retrieve the final output instead.", + FutureWarning, + ) + return self.reconstruction diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_tf_pytorch_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_tf_pytorch_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8f688af7be36439311465d019de36c20c6aaae77 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_tf_pytorch_utils.py @@ -0,0 +1,676 @@ +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch - TF 2.0 general utilities.""" + +import os +import re + +import numpy + +from .utils import ( + ExplicitEnum, + check_torch_load_is_safe, + expand_dims, + is_numpy_array, + is_safetensors_available, + is_torch_tensor, + logging, + reshape, + squeeze, + tensor_size, +) +from .utils import transpose as transpose_func + + +if is_safetensors_available(): + from safetensors import safe_open + + +logger = logging.get_logger(__name__) + + +class TransposeType(ExplicitEnum): + """ + Possible ... + """ + + NO = "no" + SIMPLE = "simple" + CONV1D = "conv1d" + CONV2D = "conv2d" + + +def convert_tf_weight_name_to_pt_weight_name( + tf_name, start_prefix_to_remove="", tf_weight_shape=None, name_scope=None +): + """ + Convert a TF 2.0 model variable name in a pytorch model weight name. + + Conventions for TF2.0 scopes -> PyTorch attribute names conversions: + + - '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch) + - '_._' is replaced by a new level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList) + + return tuple with: + + - pytorch model weight name + - transpose: `TransposeType` member indicating whether and how TF2.0 and PyTorch weights matrices should be + transposed with regards to each other + """ + if name_scope is not None: + if not tf_name.startswith(name_scope) and "final_logits_bias" not in tf_name: + raise ValueError( + f"Weight name {tf_name} does not start with name_scope {name_scope}. This is an internal error " + "in Transformers, so (unless you were doing something really evil) please open an issue to report it!" + ) + tf_name = tf_name[len(name_scope) :] + tf_name = tf_name.lstrip("/") + tf_name = tf_name.replace(":0", "") # device ids + if (len(tf_name) > 2048 and "___" in tf_name) or tf_name.count("___") > 10: + # ReDOS check + raise ValueError("TF variable name is too long or contains too many ___ separators: " + tf_name) + tf_name = re.sub( + r"/[^/]*___([^/]*)/", r"/\1/", tf_name + ) # '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch) + tf_name = tf_name.replace( + "_._", "/" + ) # '_._' is replaced by a level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList) + tf_name = re.sub(r"//+", "/", tf_name) # Remove empty levels at the end + tf_name = tf_name.split("/") # Convert from TF2.0 '/' separators to PyTorch '.' separators + # Some weights have a single name without "/" such as final_logits_bias in BART + if len(tf_name) > 1: + tf_name = tf_name[1:] # Remove level zero + + tf_weight_shape = list(tf_weight_shape) + + # When should we transpose the weights + if tf_name[-1] == "kernel" and tf_weight_shape is not None and len(tf_weight_shape) == 4: + transpose = TransposeType.CONV2D + elif tf_name[-1] == "kernel" and tf_weight_shape is not None and len(tf_weight_shape) == 3: + transpose = TransposeType.CONV1D + elif bool( + tf_name[-1] in ["kernel", "pointwise_kernel", "depthwise_kernel"] + or "emb_projs" in tf_name + or "out_projs" in tf_name + ): + transpose = TransposeType.SIMPLE + else: + transpose = TransposeType.NO + + # Convert standard TF2.0 names in PyTorch names + if tf_name[-1] == "kernel" or tf_name[-1] == "embeddings" or tf_name[-1] == "gamma": + tf_name[-1] = "weight" + if tf_name[-1] == "beta": + tf_name[-1] = "bias" + + # The SeparableConv1D TF layer contains two weights that are translated to PyTorch Conv1D here + if tf_name[-1] == "pointwise_kernel" or tf_name[-1] == "depthwise_kernel": + tf_name[-1] = tf_name[-1].replace("_kernel", ".weight") + + # Remove prefix if needed + tf_name = ".".join(tf_name) + if start_prefix_to_remove: + tf_name = tf_name.replace(start_prefix_to_remove, "", 1) + + return tf_name, transpose + + +def apply_transpose(transpose: TransposeType, weight, match_shape=None, pt_to_tf=True): + """ + Apply a transpose to some weight then tries to reshape the weight to the same shape as a given shape, all in a + framework agnostic way. + """ + if transpose is TransposeType.CONV2D: + # Conv2D weight: + # PT: (num_out_channel, num_in_channel, kernel[0], kernel[1]) + # -> TF: (kernel[0], kernel[1], num_in_channel, num_out_channel) + axes = (2, 3, 1, 0) if pt_to_tf else (3, 2, 0, 1) + weight = transpose_func(weight, axes=axes) + elif transpose is TransposeType.CONV1D: + # Conv1D weight: + # PT: (num_out_channel, num_in_channel, kernel) + # -> TF: (kernel, num_in_channel, num_out_channel) + weight = transpose_func(weight, axes=(2, 1, 0)) + elif transpose is TransposeType.SIMPLE: + weight = transpose_func(weight) + + if match_shape is None: + return weight + + if len(match_shape) < len(weight.shape): + weight = squeeze(weight) + elif len(match_shape) > len(weight.shape): + weight = expand_dims(weight, axis=0) + + if list(match_shape) != list(weight.shape): + try: + weight = reshape(weight, match_shape) + except AssertionError as e: + e.args += (match_shape, match_shape) + raise e + + return weight + + +##################### +# PyTorch => TF 2.0 # +##################### + + +def load_pytorch_checkpoint_in_tf2_model( + tf_model, + pytorch_checkpoint_path, + tf_inputs=None, + allow_missing_keys=False, + output_loading_info=False, + _prefix=None, + tf_to_pt_weight_rename=None, +): + """Load pytorch checkpoints in a TF 2.0 model""" + try: + import tensorflow as tf # noqa: F401 + import torch # noqa: F401 + from safetensors.torch import load_file as safe_load_file # noqa: F401 + except ImportError: + logger.error( + "Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see " + "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions." + ) + raise + + # Treats a single file as a collection of shards with 1 shard. + if isinstance(pytorch_checkpoint_path, str): + pytorch_checkpoint_path = [pytorch_checkpoint_path] + + # Loads all shards into a single state dictionary + pt_state_dict = {} + for path in pytorch_checkpoint_path: + pt_path = os.path.abspath(path) + logger.info(f"Loading PyTorch weights from {pt_path}") + if pt_path.endswith(".safetensors"): + state_dict = safe_load_file(pt_path) + else: + check_torch_load_is_safe() + state_dict = torch.load(pt_path, map_location="cpu", weights_only=True) + + pt_state_dict.update(state_dict) + + logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters") + + return load_pytorch_weights_in_tf2_model( + tf_model, + pt_state_dict, + tf_inputs=tf_inputs, + allow_missing_keys=allow_missing_keys, + output_loading_info=output_loading_info, + _prefix=_prefix, + tf_to_pt_weight_rename=tf_to_pt_weight_rename, + ) + + +def load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=None, allow_missing_keys=False): + """Load pytorch checkpoints in a TF 2.0 model""" + pt_state_dict = pt_model.state_dict() + + return load_pytorch_weights_in_tf2_model( + tf_model, pt_state_dict, tf_inputs=tf_inputs, allow_missing_keys=allow_missing_keys + ) + + +def load_pytorch_weights_in_tf2_model( + tf_model, + pt_state_dict, + tf_inputs=None, + allow_missing_keys=False, + output_loading_info=False, + _prefix=None, + tf_to_pt_weight_rename=None, +): + """Load pytorch state_dict in a TF 2.0 model.""" + try: + import tensorflow as tf # noqa: F401 + import torch # noqa: F401 + except ImportError: + logger.error( + "Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see " + "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions." + ) + raise + + # Numpy doesn't understand bfloat16, so upcast to a dtype that doesn't lose precision + pt_state_dict = { + k: v.numpy() if v.dtype != torch.bfloat16 else v.float().numpy() for k, v in pt_state_dict.items() + } + return load_pytorch_state_dict_in_tf2_model( + tf_model, + pt_state_dict, + tf_inputs=tf_inputs, + allow_missing_keys=allow_missing_keys, + output_loading_info=output_loading_info, + _prefix=_prefix, + tf_to_pt_weight_rename=tf_to_pt_weight_rename, + ) + + +def _log_key_warnings(missing_keys, unexpected_keys, mismatched_keys, class_name): + if len(unexpected_keys) > 0: + logger.warning( + "Some weights of the PyTorch model were not used when initializing the TF 2.0 model" + f" {class_name}: {unexpected_keys}\n- This IS expected if you are initializing" + f" {class_name} from a PyTorch model trained on another task or with another architecture" + " (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).\n- This IS" + f" NOT expected if you are initializing {class_name} from a PyTorch model that you expect" + " to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a" + " BertForSequenceClassification model)." + ) + else: + logger.warning(f"All PyTorch model weights were used when initializing {class_name}.\n") + if len(missing_keys) > 0: + logger.warning( + f"Some weights or buffers of the TF 2.0 model {class_name} were not initialized from the" + f" PyTorch model and are newly initialized: {missing_keys}\nYou should probably TRAIN this model on a" + " down-stream task to be able to use it for predictions and inference." + ) + else: + logger.warning( + f"All the weights of {class_name} were initialized from the PyTorch model.\n" + "If your task is similar to the task the model of the checkpoint was trained on, " + f"you can already use {class_name} for predictions without further training." + ) + + if len(mismatched_keys) > 0: + mismatched_warning = "\n".join( + [ + f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" + for key, shape1, shape2 in mismatched_keys + ] + ) + logger.warning( + f"Some weights of {class_name} were not initialized from the model checkpoint" + f" are newly initialized because the shapes did not" + f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able" + " to use it for predictions and inference." + ) + + +def load_pytorch_state_dict_in_tf2_model( + tf_model, + pt_state_dict, + tf_inputs=None, + allow_missing_keys=False, + output_loading_info=False, + _prefix=None, + tf_to_pt_weight_rename=None, + ignore_mismatched_sizes=False, + skip_logger_warnings=False, +): + """Load a pytorch state_dict in a TF 2.0 model. pt_state_dict can be either an actual dict or a lazy-loading + safetensors archive created with the safe_open() function.""" + import tensorflow as tf + + if tf_inputs is None: + tf_inputs = tf_model.dummy_inputs + + if _prefix is None: + _prefix = "" + if tf_inputs: + with tf.name_scope(_prefix): + tf_model(tf_inputs, training=False) # Make sure model is built + # Convert old format to new format if needed from a PyTorch state_dict + tf_keys_to_pt_keys = {} + for key in pt_state_dict: + new_key = None + if "gamma" in key: + new_key = key.replace("gamma", "weight") + if "beta" in key: + new_key = key.replace("beta", "bias") + if "running_var" in key: + new_key = key.replace("running_var", "moving_variance") + if "running_mean" in key: + new_key = key.replace("running_mean", "moving_mean") + + # New `weight_norm` from https://github.com/huggingface/transformers/pull/24030 + key_components = key.split(".") + name = None + if key_components[-3::2] == ["parametrizations", "original0"]: + name = key_components[-2] + "_g" + elif key_components[-3::2] == ["parametrizations", "original1"]: + name = key_components[-2] + "_v" + if name is not None: + key_components = key_components[:-3] + [name] + new_key = ".".join(key_components) + + if new_key is None: + new_key = key + tf_keys_to_pt_keys[new_key] = key + + # Matt: All TF models store the actual model stem in a MainLayer class, including the base model. + # In PT, the derived models (with heads) use the base model class as the stem instead, + # and there is no MainLayer class. This means that TF base classes have one + # extra layer in their weight names, corresponding to the MainLayer class. This code block compensates for that. + start_prefix_to_remove = "" + if not any(s.startswith(tf_model.base_model_prefix) for s in tf_keys_to_pt_keys): + start_prefix_to_remove = tf_model.base_model_prefix + "." + + symbolic_weights = tf_model.trainable_weights + tf_model.non_trainable_weights + tf_loaded_numel = 0 + all_pytorch_weights = set(tf_keys_to_pt_keys.keys()) + missing_keys = [] + mismatched_keys = [] + is_safetensor_archive = hasattr(pt_state_dict, "get_tensor") + for symbolic_weight in symbolic_weights: + sw_name = symbolic_weight.name + name, transpose = convert_tf_weight_name_to_pt_weight_name( + sw_name, + start_prefix_to_remove=start_prefix_to_remove, + tf_weight_shape=symbolic_weight.shape, + name_scope=_prefix, + ) + if tf_to_pt_weight_rename is not None: + aliases = tf_to_pt_weight_rename(name) # Is a tuple to account for possible name aliasing + for alias in aliases: # The aliases are in priority order, take the first one that matches + if alias in tf_keys_to_pt_keys: + name = alias + break + else: + # If none of the aliases match, just use the first one (it'll be reported as missing) + name = aliases[0] + + # Find associated numpy array in pytorch model state dict + if name not in tf_keys_to_pt_keys: + if allow_missing_keys: + missing_keys.append(name) + continue + elif tf_model._keys_to_ignore_on_load_missing is not None: + # authorized missing keys don't have to be loaded + if any(re.search(pat, name) is not None for pat in tf_model._keys_to_ignore_on_load_missing): + continue + raise AttributeError(f"{name} not found in PyTorch model") + state_dict_name = tf_keys_to_pt_keys[name] + if is_safetensor_archive: + array = pt_state_dict.get_tensor(state_dict_name) + else: + array = pt_state_dict[state_dict_name] + try: + array = apply_transpose(transpose, array, symbolic_weight.shape) + except tf.errors.InvalidArgumentError as e: + if not ignore_mismatched_sizes: + error_msg = str(e) + error_msg += ( + "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method." + ) + raise tf.errors.InvalidArgumentError(error_msg) + else: + mismatched_keys.append((name, array.shape, symbolic_weight.shape)) + continue + + tf_loaded_numel += tensor_size(array) + + symbolic_weight.assign(tf.cast(array, symbolic_weight.dtype)) + del array # Immediately free memory to keep peak usage as low as possible + all_pytorch_weights.discard(name) + + logger.info(f"Loaded {tf_loaded_numel:,} parameters in the TF 2.0 model.") + + unexpected_keys = list(all_pytorch_weights) + + if tf_model._keys_to_ignore_on_load_missing is not None: + for pat in tf_model._keys_to_ignore_on_load_missing: + missing_keys = [k for k in missing_keys if re.search(pat, k) is None] + if tf_model._keys_to_ignore_on_load_unexpected is not None: + for pat in tf_model._keys_to_ignore_on_load_unexpected: + unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] + if not skip_logger_warnings: + _log_key_warnings(missing_keys, unexpected_keys, mismatched_keys, class_name=tf_model.__class__.__name__) + + if output_loading_info: + loading_info = { + "missing_keys": missing_keys, + "unexpected_keys": unexpected_keys, + "mismatched_keys": mismatched_keys, + } + return tf_model, loading_info + + return tf_model + + +def load_sharded_pytorch_safetensors_in_tf2_model( + tf_model, + safetensors_shards, + tf_inputs=None, + allow_missing_keys=False, + output_loading_info=False, + _prefix=None, + tf_to_pt_weight_rename=None, + ignore_mismatched_sizes=False, +): + all_loading_infos = [] + for shard in safetensors_shards: + with safe_open(shard, framework="tf") as safetensors_archive: + tf_model, loading_info = load_pytorch_state_dict_in_tf2_model( + tf_model, + safetensors_archive, + tf_inputs=tf_inputs, + allow_missing_keys=allow_missing_keys, + output_loading_info=True, + _prefix=_prefix, + tf_to_pt_weight_rename=tf_to_pt_weight_rename, + ignore_mismatched_sizes=ignore_mismatched_sizes, + skip_logger_warnings=True, # We will emit merged warnings at the end + ) + all_loading_infos.append(loading_info) + # Now we just need to merge the loading info + # Keys are missing only if they're missing in *every* shard + missing_keys = sorted(set.intersection(*[set(info["missing_keys"]) for info in all_loading_infos])) + # Keys are unexpected/mismatched if they're unexpected/mismatched in *any* shard + unexpected_keys = sum([info["unexpected_keys"] for info in all_loading_infos], []) + mismatched_keys = sum([info["mismatched_keys"] for info in all_loading_infos], []) + + _log_key_warnings(missing_keys, unexpected_keys, mismatched_keys, class_name=tf_model.__class__.__name__) + + if output_loading_info: + loading_info = { + "missing_keys": missing_keys, + "unexpected_keys": unexpected_keys, + "mismatched_keys": mismatched_keys, + } + return tf_model, loading_info + + return tf_model + + +##################### +# TF 2.0 => PyTorch # +##################### + + +def load_tf2_checkpoint_in_pytorch_model( + pt_model, tf_checkpoint_path, tf_inputs=None, allow_missing_keys=False, output_loading_info=False +): + """ + Load TF 2.0 HDF5 checkpoint in a PyTorch model We use HDF5 to easily do transfer learning (see + https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357). + """ + try: + import tensorflow as tf # noqa: F401 + import torch # noqa: F401 + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see " + "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions." + ) + raise + + import transformers + + from .modeling_tf_utils import load_tf_weights + + logger.info(f"Loading TensorFlow weights from {tf_checkpoint_path}") + + # Instantiate and load the associated TF 2.0 model + tf_model_class_name = "TF" + pt_model.__class__.__name__ # Add "TF" at the beginning + tf_model_class = getattr(transformers, tf_model_class_name) + tf_model = tf_model_class(pt_model.config) + + if tf_inputs is None: + tf_inputs = tf_model.dummy_inputs + + if tf_inputs is not None: + tf_model(tf_inputs, training=False) # Make sure model is built + + load_tf_weights(tf_model, tf_checkpoint_path) + + return load_tf2_model_in_pytorch_model( + pt_model, tf_model, allow_missing_keys=allow_missing_keys, output_loading_info=output_loading_info + ) + + +def load_tf2_model_in_pytorch_model(pt_model, tf_model, allow_missing_keys=False, output_loading_info=False): + """Load TF 2.0 model in a pytorch model""" + weights = tf_model.weights + + return load_tf2_weights_in_pytorch_model( + pt_model, weights, allow_missing_keys=allow_missing_keys, output_loading_info=output_loading_info + ) + + +def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=False, output_loading_info=False): + """Load TF2.0 symbolic weights in a PyTorch model""" + try: + import tensorflow as tf # noqa: F401 + import torch # noqa: F401 + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see " + "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions." + ) + raise + + tf_state_dict = {tf_weight.name: tf_weight.numpy() for tf_weight in tf_weights} + return load_tf2_state_dict_in_pytorch_model( + pt_model, tf_state_dict, allow_missing_keys=allow_missing_keys, output_loading_info=output_loading_info + ) + + +def load_tf2_state_dict_in_pytorch_model(pt_model, tf_state_dict, allow_missing_keys=False, output_loading_info=False): + import torch + + new_pt_params_dict = {} + current_pt_params_dict = dict(pt_model.named_parameters()) + + # Make sure we are able to load PyTorch base models as well as derived models (with heads) + # TF models always have a prefix, some of PyTorch models (base ones) don't + start_prefix_to_remove = "" + if not any(s.startswith(pt_model.base_model_prefix) for s in current_pt_params_dict): + start_prefix_to_remove = pt_model.base_model_prefix + "." + + # Build a map from potential PyTorch weight names to TF 2.0 Variables + tf_weights_map = {} + for name, tf_weight in tf_state_dict.items(): + pt_name, transpose = convert_tf_weight_name_to_pt_weight_name( + name, start_prefix_to_remove=start_prefix_to_remove, tf_weight_shape=tf_weight.shape + ) + tf_weights_map[pt_name] = (tf_weight, transpose) + + all_tf_weights = set(tf_weights_map.keys()) + loaded_pt_weights_data_ptr = {} + missing_keys_pt = [] + for pt_weight_name, pt_weight in current_pt_params_dict.items(): + # Handle PyTorch shared weight not duplicated in TF 2.0 + if pt_weight.data_ptr() in loaded_pt_weights_data_ptr and pt_weight.data_ptr() != 0: + new_pt_params_dict[pt_weight_name] = loaded_pt_weights_data_ptr[pt_weight.data_ptr()] + continue + + pt_weight_name_to_check = pt_weight_name + # New `weight_norm` from https://github.com/huggingface/transformers/pull/24030 + key_components = pt_weight_name.split(".") + name = None + if key_components[-3::2] == ["parametrizations", "original0"]: + name = key_components[-2] + "_g" + elif key_components[-3::2] == ["parametrizations", "original1"]: + name = key_components[-2] + "_v" + if name is not None: + key_components = key_components[:-3] + [name] + pt_weight_name_to_check = ".".join(key_components) + + # Find associated numpy array in pytorch model state dict + if pt_weight_name_to_check not in tf_weights_map: + if allow_missing_keys: + missing_keys_pt.append(pt_weight_name) + continue + + raise AttributeError(f"{pt_weight_name} not found in TF 2.0 model") + + array, transpose = tf_weights_map[pt_weight_name_to_check] + + array = apply_transpose(transpose, array, pt_weight.shape, pt_to_tf=False) + + if numpy.isscalar(array): + array = numpy.array(array) + if not is_torch_tensor(array) and not is_numpy_array(array): + array = array.numpy() + if is_numpy_array(array): + # Convert to torch tensor + array = torch.from_numpy(array) + + new_pt_params_dict[pt_weight_name] = array + loaded_pt_weights_data_ptr[pt_weight.data_ptr()] = array + all_tf_weights.discard(pt_weight_name) + + missing_keys, unexpected_keys = pt_model.load_state_dict(new_pt_params_dict, strict=False) + missing_keys += missing_keys_pt + + # Some models may have keys that are not in the state by design, removing them before needlessly warning + # the user. + if pt_model._keys_to_ignore_on_load_missing is not None: + for pat in pt_model._keys_to_ignore_on_load_missing: + missing_keys = [k for k in missing_keys if re.search(pat, k) is None] + + if pt_model._keys_to_ignore_on_load_unexpected is not None: + for pat in pt_model._keys_to_ignore_on_load_unexpected: + unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] + + if len(unexpected_keys) > 0: + logger.warning( + "Some weights of the TF 2.0 model were not used when initializing the PyTorch model" + f" {pt_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing" + f" {pt_model.__class__.__name__} from a TF 2.0 model trained on another task or with another architecture" + " (e.g. initializing a BertForSequenceClassification model from a TFBertForPreTraining model).\n- This IS" + f" NOT expected if you are initializing {pt_model.__class__.__name__} from a TF 2.0 model that you expect" + " to be exactly identical (e.g. initializing a BertForSequenceClassification model from a" + " TFBertForSequenceClassification model)." + ) + else: + logger.warning(f"All TF 2.0 model weights were used when initializing {pt_model.__class__.__name__}.\n") + if len(missing_keys) > 0: + logger.warning( + f"Some weights of {pt_model.__class__.__name__} were not initialized from the TF 2.0 model and are newly" + f" initialized: {missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to" + " use it for predictions and inference." + ) + else: + logger.warning( + f"All the weights of {pt_model.__class__.__name__} were initialized from the TF 2.0 model.\n" + "If your task is similar to the task the model of the checkpoint was trained on, " + f"you can already use {pt_model.__class__.__name__} for predictions without further training." + ) + + logger.info(f"Weights or buffers not loaded from TF 2.0 model: {all_tf_weights}") + + if output_loading_info: + loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys} + return pt_model, loading_info + + return pt_model diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_tf_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_tf_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c7bb80656d1b8596e9f3313b6f768c6d2251099d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/modeling_tf_utils.py @@ -0,0 +1,3529 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""TF general model utils.""" + +from __future__ import annotations + +import functools +import gc +import inspect +import json +import os +import pickle +import re +import warnings +from collections.abc import Mapping +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable, Union + +import h5py +import numpy as np +import tensorflow as tf +from packaging.version import parse + +from . import DataCollatorWithPadding, DefaultDataCollator +from .activations_tf import get_tf_activation +from .configuration_utils import PretrainedConfig +from .dynamic_module_utils import custom_object_save +from .generation import GenerationConfig, TFGenerationMixin +from .tf_utils import ( + convert_batch_encoding, + expand_1d, + load_attributes_from_hdf5_group, + save_attributes_to_hdf5_group, + shape_list, +) +from .utils import ( + SAFE_WEIGHTS_INDEX_NAME, + SAFE_WEIGHTS_NAME, + TF2_WEIGHTS_INDEX_NAME, + TF2_WEIGHTS_NAME, + TF_WEIGHTS_NAME, + WEIGHTS_INDEX_NAME, + WEIGHTS_NAME, + ModelOutput, + PushToHubMixin, + cached_file, + download_url, + find_labels, + has_file, + is_offline_mode, + is_remote_url, + is_safetensors_available, + is_tf_symbolic_tensor, + logging, + requires_backends, + working_or_temp_dir, +) +from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files + + +if is_safetensors_available(): + from safetensors import safe_open + from safetensors.tensorflow import save_file as safe_save_file + +if TYPE_CHECKING: + from . import PreTrainedTokenizerBase + +logger = logging.get_logger(__name__) + +if "TF_USE_LEGACY_KERAS" not in os.environ: + os.environ["TF_USE_LEGACY_KERAS"] = "1" # Compatibility fix to make sure tf.keras stays at Keras 2 +elif os.environ["TF_USE_LEGACY_KERAS"] != "1": + logger.warning( + "Transformers is only compatible with Keras 2, but you have explicitly set `TF_USE_LEGACY_KERAS` to `0`. " + "This may result in unexpected behaviour or errors if Keras 3 objects are passed to Transformers models." + ) + +try: + import tf_keras as keras + from tf_keras import backend as K +except (ModuleNotFoundError, ImportError): + import keras + from keras import backend as K + + if parse(keras.__version__).major > 2: + raise ValueError( + "Your currently installed version of Keras is Keras 3, but this is not yet supported in " + "Transformers. Please install the backwards-compatible tf-keras package with " + "`pip install tf-keras`." + ) + + +tf_logger = tf.get_logger() + +TFModelInputType = Union[ + list[tf.Tensor], + list[np.ndarray], + dict[str, tf.Tensor], + dict[str, np.ndarray], + tf.Tensor, + np.ndarray, +] + + +def dummy_loss(y_true, y_pred): + if y_pred.shape.rank <= 1: + return y_pred + else: + reduction_axes = list(range(1, y_pred.shape.rank)) + return tf.reduce_mean(y_pred, axis=reduction_axes) + + +class TFModelUtilsMixin: + """ + A few utilities for `keras.Model`, to be used as a mixin. + """ + + def num_parameters(self, only_trainable: bool = False) -> int: + """ + Get the number of (optionally, trainable) parameters in the model. + + Args: + only_trainable (`bool`, *optional*, defaults to `False`): + Whether or not to return only the number of trainable parameters + + Returns: + `int`: The number of parameters. + """ + if only_trainable: + return int(sum(np.prod(w.shape.as_list()) for w in self.trainable_variables)) + else: + return self.count_params() + + +def keras_serializable(cls): + """ + Decorate a Keras Layer class to support Keras serialization. + + This is done by: + + 1. Adding a `transformers_config` dict to the Keras config dictionary in `get_config` (called by Keras at + serialization time. + 2. Wrapping `__init__` to accept that `transformers_config` dict (passed by Keras at deserialization time) and + convert it to a config object for the actual layer initializer. + 3. Registering the class as a custom object in Keras (if the Tensorflow version supports this), so that it does not + need to be supplied in `custom_objects` in the call to `keras.models.load_model`. + + Args: + cls (a `keras.layers.Layers subclass`): + Typically a `TF.MainLayer` class in this project, in general must accept a `config` argument to its + initializer. + + Returns: + The same class object, with modifications for Keras deserialization. + """ + initializer = cls.__init__ + + config_class = getattr(cls, "config_class", None) + if config_class is None: + raise AttributeError("Must set `config_class` to use @keras_serializable") + + @functools.wraps(initializer) + def wrapped_init(self, *args, **kwargs): + config = args[0] if args and isinstance(args[0], PretrainedConfig) else kwargs.pop("config", None) + + if isinstance(config, dict): + config = config_class.from_dict(config) + initializer(self, config, *args, **kwargs) + elif isinstance(config, PretrainedConfig): + if len(args) > 0: + initializer(self, *args, **kwargs) + else: + initializer(self, config, *args, **kwargs) + else: + raise TypeError("Must pass either `config` (PretrainedConfig) or `config` (dict)") + + self._config = config + self._kwargs = kwargs + + cls.__init__ = wrapped_init + + if not hasattr(cls, "get_config"): + raise TypeError("Only use @keras_serializable on keras.layers.Layer subclasses") + if hasattr(cls.get_config, "_is_default"): + + def get_config(self): + cfg = super(cls, self).get_config() + cfg["config"] = self._config.to_dict() + cfg.update(self._kwargs) + return cfg + + cls.get_config = get_config + + cls._keras_serializable = True + if hasattr(keras.utils, "register_keras_serializable"): + cls = keras.utils.register_keras_serializable()(cls) + return cls + + +class TFCausalLanguageModelingLoss: + """ + Loss function suitable for causal language modeling (CLM), that is, the task of guessing the next token. + + + + Any label of -100 will be ignored (along with the corresponding logits) in the loss computation. + + + """ + + def hf_compute_loss(self, labels, logits): + loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE) + if self.config.tf_legacy_loss: + # make sure only labels that are not equal to -100 affect the loss + active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100) + reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss) + labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss) + return loss_fn(labels, reduced_logits) + + # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway + unmasked_loss = loss_fn(tf.nn.relu(labels), logits) + # make sure only labels that are not equal to -100 affect the loss + loss_mask = tf.cast(labels != -100, dtype=unmasked_loss.dtype) + masked_loss = unmasked_loss * loss_mask + reduced_masked_loss = tf.reduce_sum(masked_loss) / tf.reduce_sum(loss_mask) + return tf.reshape(reduced_masked_loss, (1,)) + + +class TFQuestionAnsweringLoss: + """ + Loss function suitable for question answering. + """ + + def hf_compute_loss(self, labels, logits): + loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE) + start_loss = loss_fn(labels["start_position"], logits[0]) + end_loss = loss_fn(labels["end_position"], logits[1]) + + return (start_loss + end_loss) / 2.0 + + +class TFTokenClassificationLoss: + """ + Loss function suitable for token classification. + + + + Any label of -100 will be ignored (along with the corresponding logits) in the loss computation. + + + """ + + def hf_compute_loss(self, labels, logits): + loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE) + if tf.executing_eagerly(): # Data-dependent conditionals are forbidden in XLA + if tf.math.reduce_any(labels == -1): + tf.print("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.") + + if self.config.tf_legacy_loss: + # make sure only labels that are not equal to -100 + # are taken into account as loss + if tf.math.reduce_any(labels == -1): + tf.print("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.") + active_loss = tf.reshape(labels, (-1,)) != -1 + else: + active_loss = tf.reshape(labels, (-1,)) != -100 + reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss) + labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss) + + return loss_fn(labels, reduced_logits) + + # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway + unmasked_loss = loss_fn(tf.nn.relu(labels), logits) + # make sure only labels that are not equal to -100 or -1 + # are taken into account as loss + loss_mask = tf.cast(labels >= 0, dtype=unmasked_loss.dtype) + # Avoid possible division by zero later + # Masked positions will have a loss of NaN because -100 and -1 are not valid labels + masked_loss = unmasked_loss * loss_mask + reduced_masked_loss = tf.reduce_sum(masked_loss) / tf.reduce_sum(loss_mask) + return tf.reshape(reduced_masked_loss, (1,)) + + +class TFSequenceClassificationLoss: + """ + Loss function suitable for sequence classification. + """ + + def hf_compute_loss(self, labels, logits): + if logits.shape.rank == 1 or logits.shape[1] == 1: + loss_fn = keras.losses.MeanSquaredError(reduction=keras.losses.Reduction.NONE) + if labels.shape.rank == 1: + # MeanSquaredError returns a scalar loss if the labels are 1D, so avoid that + labels = tf.expand_dims(labels, axis=-1) + else: + loss_fn = keras.losses.SparseCategoricalCrossentropy( + from_logits=True, reduction=keras.losses.Reduction.NONE + ) + + return loss_fn(labels, logits) + + +class TFMultipleChoiceLoss: + """Loss function suitable for multiple choice tasks.""" + + def hf_compute_loss(self, labels, logits): + loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE) + return loss_fn(labels, logits) + + +class TFMaskedLanguageModelingLoss(TFCausalLanguageModelingLoss): + """ + Loss function suitable for masked language modeling (MLM), that is, the task of guessing the masked tokens. + + + + Any label of -100 will be ignored (along with the corresponding logits) in the loss computation. + + + """ + + +class TFNextSentencePredictionLoss: + """ + Loss function suitable for next sentence prediction (NSP), that is, the task of guessing the next sentence. + + + + Any label of -100 will be ignored (along with the corresponding logits) in the loss computation. + + + """ + + def hf_compute_loss(self, labels, logits): + loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE) + if self.config.tf_legacy_loss: + # make sure only labels that are not equal to -100 + # are taken into account as loss + next_sentence_active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100) + next_sentence_reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, 2)), next_sentence_active_loss) + next_sentence_label = tf.boolean_mask(tf.reshape(labels, (-1,)), next_sentence_active_loss) + + return loss_fn(next_sentence_label, next_sentence_reduced_logits) + + # make sure only labels that are not equal to -100 + # are taken into account as loss + + # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway + unmasked_ns_loss = loss_fn(y_true=tf.nn.relu(labels), y_pred=logits) + ns_loss_mask = tf.cast(labels != -100, dtype=unmasked_ns_loss.dtype) + # Just zero out samples where label is -100, no reduction + masked_ns_loss = unmasked_ns_loss * ns_loss_mask + + return masked_ns_loss + + +def booleans_processing(config, **kwargs): + """ + Process the input booleans of each model. + + Args: + config ([`PretrainedConfig`]): + The config of the running model. + **kwargs: + The boolean parameters + + Returns: + A dictionary with the proper values for each boolean + """ + final_booleans = {} + + # Pure conv models (such as ConvNext) do not have `output_attentions`. If the signature has + # `output_attentions`, it will be present here in `kwargs`, even if unset (in that case, as `None`) + if "output_attentions" in kwargs: + final_booleans["output_attentions"] = ( + kwargs["output_attentions"] if kwargs["output_attentions"] is not None else config.output_attentions + ) + final_booleans["output_hidden_states"] = ( + kwargs["output_hidden_states"] if kwargs["output_hidden_states"] is not None else config.output_hidden_states + ) + final_booleans["return_dict"] = kwargs["return_dict"] if kwargs["return_dict"] is not None else config.return_dict + + if "use_cache" in kwargs: + final_booleans["use_cache"] = ( + kwargs["use_cache"] if kwargs["use_cache"] is not None else getattr(config, "use_cache", None) + ) + return final_booleans + + +def unpack_inputs(func): + """ + Decorator that processes the inputs to a Keras layer, passing them to the layer as keyword arguments. This enables + downstream use of the inputs by their variable name, even if they arrive packed as a dictionary in the first input + (common case in Keras). + + Args: + func (`callable`): + The callable function of the TensorFlow model. + + + Returns: + A callable that wraps the original `func` with the behavior described above. + """ + + original_signature = inspect.signature(func) + + @functools.wraps(func) + def run_call_with_unpacked_inputs(self, *args, **kwargs): + # isolates the actual `**kwargs` for the decorated function + kwargs_call = {key: val for key, val in kwargs.items() if key not in dict(original_signature.parameters)} + fn_args_and_kwargs = {key: val for key, val in kwargs.items() if key not in kwargs_call} + fn_args_and_kwargs.update({"kwargs_call": kwargs_call}) + + # move any arg into kwargs, if they exist + fn_args_and_kwargs.update(dict(zip(func.__code__.co_varnames[1:], args))) + + # Encoder Decoder models delegate the application of the configuration options to their inner models. + if "EncoderDecoder" in self.__class__.__name__: + config = None + else: + config = self.config + + unpacked_inputs = input_processing(func, config, **fn_args_and_kwargs) + return func(self, **unpacked_inputs) + + # Keras enforces the first layer argument to be passed, and checks it through `inspect.getfullargspec()`. This + # function does not follow wrapper chains (i.e. ignores `functools.wraps()`), meaning that without the line below + # Keras would attempt to check the first argument against the literal signature of the wrapper. + run_call_with_unpacked_inputs.__signature__ = original_signature + + return run_call_with_unpacked_inputs + + +def input_processing(func, config, **kwargs): + """ + Process the input of each TensorFlow model including the booleans. In case of a list of symbolic inputs, each input + has to be named accordingly to the parameters name, i.e. `input_ids = keras.Input(shape=(128,), dtype='int32', + name="input_ids")` otherwise the order of the tensors will not be guaranteed during the training. + + Args: + func (`callable`): + The callable function of the TensorFlow model. + config ([`PretrainedConfig`]): + The config of the running model. + **kwargs: + The inputs of the model. + + Returns: + Two lists, one for the missing layers, and another one for the unexpected layers. + """ + signature = dict(inspect.signature(func).parameters) + has_kwargs = bool(signature.pop("kwargs", None)) + signature.pop("self", None) + parameter_names = list(signature.keys()) + main_input_name = parameter_names[0] + main_input = kwargs.pop(main_input_name, None) + output = {} + allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict, np.ndarray) + + if "inputs" in kwargs["kwargs_call"]: + warnings.warn( + "The `inputs` argument is deprecated and will be removed in a future version, use `input_ids` instead.", + FutureWarning, + ) + + output["input_ids"] = kwargs["kwargs_call"].pop("inputs") + + if "decoder_cached_states" in kwargs["kwargs_call"]: + warnings.warn( + "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use" + " `past_key_values` instead.", + FutureWarning, + ) + output["past_key_values"] = kwargs["kwargs_call"].pop("decoder_cached_states") + + if "past" in kwargs["kwargs_call"] and "past_key_values" in parameter_names: + warnings.warn( + "The `past` argument is deprecated and will be removed in a future version, use `past_key_values`" + " instead.", + FutureWarning, + ) + kwargs["past_key_values"] = kwargs["kwargs_call"].pop("past") + elif "past_key_values" in kwargs["kwargs_call"] and "past" in parameter_names: + kwargs["past"] = kwargs["kwargs_call"].pop("past_key_values") + + if has_kwargs: + output["kwargs"] = kwargs.pop("kwargs_call", {}) + else: + if len(kwargs["kwargs_call"]) > 0: + raise ValueError( + "The following keyword arguments are not supported by this model:" + f" {list(kwargs['kwargs_call'].keys())}." + ) + kwargs.pop("kwargs_call") + + for k, v in kwargs.items(): + if isinstance(v, allowed_types) or tf.is_tensor(v) or v is None: + output[k] = v + else: + raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.") + + if isinstance(main_input, (tuple, list)): + for i, input in enumerate(main_input): + # EagerTensors don't allow to use the .name property so we check for a real Tensor + if is_tf_symbolic_tensor(input): + # Tensor names have always the pattern `name:id` then we check only the + # `name` part + tensor_name = input.name.split(":")[0] + + if tensor_name in parameter_names: + output[tensor_name] = input + else: + output[parameter_names[i]] = input + elif isinstance(input, allowed_types) or input is None: + output[parameter_names[i]] = input + else: + raise ValueError( + f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for" + f" {parameter_names[i]}." + ) + elif isinstance(main_input, Mapping): + if "inputs" in main_input: + warnings.warn( + "The `inputs` argument is deprecated and will be removed in a future version, use `input_ids`" + " instead.", + FutureWarning, + ) + + output["input_ids"] = main_input.pop("inputs") + + if "decoder_cached_states" in main_input: + warnings.warn( + "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use" + " `past_key_values` instead.", + FutureWarning, + ) + output["past_key_values"] = main_input.pop("decoder_cached_states") + + for k, v in dict(main_input).items(): + if isinstance(v, allowed_types) or v is None: + output[k] = v + elif k not in parameter_names and "args" not in parameter_names: + logger.warning( + f"The parameter {k} does not belongs to the parameter list {parameter_names} and will be ignored." + ) + continue + else: + raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.") + else: + if tf.is_tensor(main_input) or main_input is None: + output[main_input_name] = main_input + else: + raise ValueError( + f"Data of type {type(main_input)} is not allowed only {allowed_types} is accepted for" + f" {main_input_name}." + ) + + # Populates any unspecified argument with their default value, according to the signature. + for name in parameter_names: + if name not in list(output.keys()) and name != "args": + output[name] = kwargs.pop(name, signature[name].default) + + # When creating a SavedModel TF calls the method with LayerCall.__call__(args, **kwargs) + # So to respect the proper output we have to add this exception + if "args" in output: + if output["args"] is not None and is_tf_symbolic_tensor(output["args"]): + tensor_name = output["args"].name.split(":")[0] + output[tensor_name] = output["args"] + else: + # `args` in this case is always the first parameter, then `input_ids` + output["input_ids"] = output["args"] + + del output["args"] + + if "kwargs" in output: + del output["kwargs"] + + cast_output = {} + for key, val in output.items(): + if isinstance(val, tf.Tensor) and val.dtype == tf.int64: + cast_output[key] = tf.cast(val, tf.int32) + elif isinstance(val, np.ndarray) and val.dtype == np.int64: + cast_output[key] = val.astype(np.int32) + else: + cast_output[key] = val + + output = cast_output + del cast_output + + if config is not None: + boolean_dict = { + k: v + for k, v in output.items() + if k in ["return_dict", "output_attentions", "output_hidden_states", "use_cache"] + } + + output.update( + booleans_processing( + config=config, + **boolean_dict, + ) + ) + + return output + + +def strip_model_name_and_prefix(name, _prefix=None): + if _prefix is not None and name.startswith(_prefix): + name = name[len(_prefix) :] + if name.startswith("/"): + name = name[1:] + if "model." not in name and len(name.split("/")) > 1: + name = "/".join(name.split("/")[1:]) + return name + + +def tf_shard_checkpoint(weights, max_shard_size="10GB", weights_name: str = TF2_WEIGHTS_NAME): + """ + Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a + given size. + + The sub-checkpoints are determined by iterating through the `state_dict` in the order of its keys, so there is no + optimization made to make each sub-checkpoint as close as possible to the maximum size passed. For example, if the + limit is 10GB and we have weights of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB], + [6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB]. + + + + If one of the model's weight is bigger that `max_shard_size`, it will end up in its own sub-checkpoint which will + have a size greater than `max_shard_size`. + + + + Args: + weights (`dict[str, tf.RessourceVariable]`): The list of tf.RessourceVariable of a model to save. + max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`): + The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit + (like `"5MB"`). + """ + max_shard_size = convert_file_size_to_int(max_shard_size) + + sharded_state_dicts = [] + current_block = [] + current_block_size = 0 + total_size = 0 + + for item in weights: + weight_size = item.numpy().size * item.dtype.size + + # If this weight is going to tip up over the maximal size, we split. + if current_block_size + weight_size > max_shard_size: + sharded_state_dicts.append(current_block) + current_block = [] + current_block_size = 0 + + current_block.append(item) + current_block_size += weight_size + total_size += weight_size + + # Add the last block + sharded_state_dicts.append(current_block) + + # If we only have one shard, we return it + if len(sharded_state_dicts) == 1: + return {weights_name: sharded_state_dicts[0]}, None + + # Otherwise, let's build the index + weight_map = {} + shards = {} + for idx, shard in enumerate(sharded_state_dicts): + shard_file = weights_name.replace(".h5", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.h5") + shard_file = shard_file.replace( + ".safetensors", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors" + ) + shards[shard_file] = shard + for weight in shard: + weight_name = weight.name + weight_map[weight_name] = shard_file + + # Add the metadata + metadata = {"total_size": total_size} + index = {"metadata": metadata, "weight_map": weight_map} + return shards, index + + +def load_tf_sharded_weights(model, shard_files, ignore_mismatched_sizes=False, strict=False, _prefix=None): + """ + This is the same as `load_tf_weights` but for a sharded checkpoint. Detect missing and unexpected layers and load + the TF weights from the shard file accordingly to their names and shapes. + + This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being + loaded in the model. + + Args: + model (`keras.models.Model`): The model in which to load the checkpoint. + shard_files (`str` or `os.PathLike`): A list containing the sharded checkpoint names. + ignore_mismatched_sizes`bool`, *optional`, defaults to `True`): + Whether or not to ignore the mismatch between the sizes + strict (`bool`, *optional*, defaults to `True`): + Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint. + + Returns: + Three lists, one for the missing layers, another one for the unexpected layers, and a last one for the + mismatched layers. + """ + + # Load the index + unexpected_keys = set() + saved_keys = set() + mismatched_keys = set() + + # Since TF adds the name of the class to its weights, and uses the index and not the name of the layer to load + # the weight, we have to get rid of the first prefix of the name of the layer. + model_keys = set() + model_layer_map = {} + for i, k in enumerate(model.weights): + layer_name = k.name + if _prefix is not None and layer_name.startswith(_prefix): + layer_name = layer_name[len(_prefix) :] + layer_name = layer_name.lstrip("/") + if not ("model." in layer_name or len(layer_name.split("/")) == 1): + layer_name = "/".join(layer_name.split("/")[1:]) + model_keys.add(layer_name) + model_layer_map[layer_name] = i + + for shard_file in shard_files: + saved_weight_names_set, unexpected_keys_set, mismatched_keys_set = load_tf_shard( + model, + model_layer_map, + shard_file, + ignore_mismatched_sizes=ignore_mismatched_sizes, + _prefix=_prefix, + ) + saved_keys.update(saved_weight_names_set) + unexpected_keys.update(unexpected_keys_set) + mismatched_keys.update(mismatched_keys_set) + gc.collect() + + missing_keys = model_keys - saved_keys + if strict and (len(missing_keys) > 0 or len(unexpected_keys) > 0): + error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}" + if len(missing_keys) > 0: + str_missing_keys = ",".join([f'"{k}"' for k in missing_keys]) + error_message += f"\nMissing key(s): {str_missing_keys}." + if len(unexpected_keys) > 0: + str_unexpected_keys = ",".join([f'"{k}"' for k in unexpected_keys]) + error_message += f"\nMissing key(s): {str_unexpected_keys}." + raise RuntimeError(error_message) + + return missing_keys, unexpected_keys, mismatched_keys + + +def load_tf_shard(model, model_layer_map, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None): + """ + Loads a shard from a sharded checkpoint file. Can be either H5 or Safetensors. + Handles missing keys and unexpected keys. + + Args: + model (`keras.models.Model`): Model in which the weights are loaded + model_layer_map (`Dict`): A dictionary mapping the layer name to the index of the layer in the model. + resolved_archive_file (`str`): Path to the checkpoint file from which the weights will be loaded + ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`): Whether to ignore the mismatched keys + + Returns: + `keras.models.Model`: Three lists, one for the layers that were found and successfully restored (from the + shard file), one for the mismatched layers, and another one for the unexpected layers. + """ + saved_weight_names_set = set() + saved_weights = {} + mismatched_keys = set() + unexpected_keys = set() + # Read the H5 file + try: + with h5py.File(resolved_archive_file, "r") as sharded_checkpoint_file: + # Retrieve the name of each layer from the H5 file + saved_h5_model_layers_name = set(load_attributes_from_hdf5_group(sharded_checkpoint_file, "layer_names")) + weight_value_tuples = [] + + # Compute missing and unexpected sub layers + # Store the weights in list of tuples that looks like [(weight_object, value_of_weight),...] + for layer_name in saved_h5_model_layers_name: + h5_layer_object = sharded_checkpoint_file[layer_name] + saved_weights[layer_name] = np.asarray(h5_layer_object) + + saved_weight_names_set.add(layer_name) + + if layer_name not in model_layer_map: + unexpected_keys.add(layer_name) + else: + symbolic_weight = model.weights[model_layer_map[layer_name]] + + saved_weight_value = saved_weights[layer_name] + # If the current weight is found + if saved_weight_value is not None: + # Check if the shape of the current weight and the one from the H5 file are different + if K.int_shape(symbolic_weight) != saved_weight_value.shape: + # If yes we reshape the weight from the H5 file accordingly to the current weight + # If the two shapes are not compatible we raise an issue + try: + array = np.reshape(saved_weight_value, K.int_shape(symbolic_weight)) + except ValueError as e: + if ignore_mismatched_sizes: + mismatched_keys.add( + (layer_name, saved_weight_value.shape, K.int_shape(symbolic_weight)) + ) + continue + else: + raise e + else: + array = saved_weight_value + + # We create the tuple that will be loaded and add it to the final list + weight_value_tuples.append((symbolic_weight, array)) + + K.batch_set_value(weight_value_tuples) + + return saved_weight_names_set, unexpected_keys, mismatched_keys + + except Exception as e: + try: + with open(resolved_archive_file) as f: + if f.read().startswith("version"): + raise OSError( + "You seem to have cloned a repository without having git-lfs installed. Please install " + "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder " + "you cloned." + ) + else: + raise ValueError( + f"Unable to locate the file {resolved_archive_file} which is necessary to load this pretrained" + " model. Make sure you have saved the model properly." + ) from e + except (UnicodeDecodeError, ValueError): + raise OSError( + f"Unable to load weights from TF checkpoint file for '{resolved_archive_file}' " + f"at '{resolved_archive_file}'. " + "If you tried to load a TF model from a sharded checkpoint, you should try converting the model " + "by loading it in pytorch and saving it locally. A conversion script should be released soon." + ) + + +def load_tf_sharded_weights_from_safetensors( + model, shard_files, ignore_mismatched_sizes=False, strict=False, _prefix=None +): + """ + This is the same as `load_tf_weights_from_safetensors` but for a sharded TF-format safetensors checkpoint. + Detect missing and unexpected layers and load the TF weights from the shard file accordingly to their names and + shapes. + + This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being + loaded in the model. + + Args: + model (`keras.models.Model`): The model in which to load the checkpoint. + shard_files (`str` or `os.PathLike`): A list containing the sharded checkpoint names. + ignore_mismatched_sizes`bool`, *optional`, defaults to `True`): + Whether or not to ignore the mismatch between the sizes + strict (`bool`, *optional*, defaults to `True`): + Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint. + + Returns: + Three lists, one for the missing layers, another one for the unexpected layers, and a last one for the + mismatched layers. + """ + + # Load the index + unexpected_keys = set() + all_missing_keys = [] + mismatched_keys = set() + + for shard_file in shard_files: + missing_layers, unexpected_layers, mismatched_layers = load_tf_weights_from_safetensors( + model, + shard_file, + ignore_mismatched_sizes=ignore_mismatched_sizes, + _prefix=_prefix, + ) + all_missing_keys.append(set(missing_layers)) + unexpected_keys.update(unexpected_layers) + mismatched_keys.update(mismatched_layers) + gc.collect() + missing_keys = set.intersection(*all_missing_keys) + + if strict and (len(missing_keys) > 0 or len(unexpected_keys) > 0): + error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}" + if len(missing_keys) > 0: + str_missing_keys = ",".join([f'"{k}"' for k in missing_keys]) + error_message += f"\nMissing key(s): {str_missing_keys}." + if len(unexpected_keys) > 0: + str_unexpected_keys = ",".join([f'"{k}"' for k in unexpected_keys]) + error_message += f"\nMissing key(s): {str_unexpected_keys}." + raise RuntimeError(error_message) + + return missing_keys, unexpected_keys, mismatched_keys + + +def load_tf_weights(model, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None): + """ + Detect missing and unexpected layers and load the TF weights from the shard file accordingly to their names and + shapes. + + Args: + model (`keras.models.Model`): + The model to load the weights into. + resolved_archive_file (`str`): + The location of the H5 file. + ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`): + Whether or not to ignore weights with shapes that don't match between the checkpoint of the model. + + Returns: + Three lists, one for the missing layers, another one for the unexpected layers, and a last one for the + mismatched layers. + """ + if resolved_archive_file.endswith(".safetensors"): + load_function = load_tf_weights_from_safetensors + else: + load_function = load_tf_weights_from_h5 + + return load_function( + model, resolved_archive_file, ignore_mismatched_sizes=ignore_mismatched_sizes, _prefix=_prefix + ) + + +def load_tf_weights_from_h5(model, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None): + mismatched_layers = [] + + # Read the H5 file + with h5py.File(resolved_archive_file, "r") as sharded_checkpoint_file: + # Retrieve the name of each layer from the H5 file + saved_h5_model_layers_name = set(load_attributes_from_hdf5_group(sharded_checkpoint_file, "layer_names")) + + # Find the missing layers from the high level list of layers + missing_layers = list({layer.name for layer in model.layers} - saved_h5_model_layers_name) + + # Find the unexpected layers from the high level list of layers + unexpected_layers = list(saved_h5_model_layers_name - {layer.name for layer in model.layers}) + saved_weight_names_set = set() + symbolic_weights_names = set() + weight_value_tuples = [] + + # Compute missing and unexpected sub layers + # Store the weights in list of tuples that looks like [(weight_object, value_of_weight),...] + for layer in model.layers: + # if layer_name from the H5 file belongs to the layers from the instantiated model + if layer.name in saved_h5_model_layers_name: + # Get the H5 layer object from its name + h5_layer_object = sharded_checkpoint_file[layer.name] + # Get all the weights as a list from the layer object + symbolic_weights = layer.trainable_weights + layer.non_trainable_weights + saved_weights = {} + + # Create a dict from the H5 saved model that looks like {"weight_name": weight_value} + # And a set with only the names + for weight_name in load_attributes_from_hdf5_group(h5_layer_object, "weight_names"): + # TF names always start with the model name so we ignore it + name = "/".join(weight_name.split("/")[1:]) + + if _prefix is not None: + name = _prefix + "/" + name + + saved_weights[name] = np.asarray(h5_layer_object[weight_name]) + + # Add the updated name to the final list for computing missing/unexpected values + saved_weight_names_set.add(name) + + # Loop over each weights from the instantiated model and compare with the weights from the H5 file + for symbolic_weight in symbolic_weights: + # TF names always start with the model name so we ignore it + if _prefix is not None: + delimiter = len(_prefix.split("/")) + symbolic_weight_name = "/".join( + symbolic_weight.name.split("/")[:delimiter] + + symbolic_weight.name.split("/")[delimiter + 1 :] + ) + else: + symbolic_weight_name = "/".join(symbolic_weight.name.split("/")[1:]) + + # here we check if the current weight is among the weights from the H5 file + # If yes, get the weight_value of the corresponding weight from the H5 file + # If not, make the value to None + saved_weight_value = saved_weights.get(symbolic_weight_name) + + # Retrocompatibility patch: some embeddings are stored with the weights name (e.g. Bart's + # `model.shared/embeddings:0` are stored as `model.shared/weights:0`) + if saved_weight_value is None and symbolic_weight_name.endswith("embeddings:0"): + symbolic_weight_name = symbolic_weight_name[:-12] + "weight:0" + saved_weight_value = saved_weights.get(symbolic_weight_name) + + # Add the updated name to the final list for computing missing/unexpected values + symbolic_weights_names.add(symbolic_weight_name) + + # If the current weight is found + if saved_weight_value is not None: + # Check if the shape of the current weight and the one from the H5 file are different + if K.int_shape(symbolic_weight) != saved_weight_value.shape: + # If yes we reshape the weight from the H5 file accordingly to the current weight + # If the two shapes are not compatible we raise an issue + try: + array = np.reshape(saved_weight_value, K.int_shape(symbolic_weight)) + except ValueError as e: + if ignore_mismatched_sizes: + mismatched_layers.append( + (symbolic_weight_name, saved_weight_value.shape, K.int_shape(symbolic_weight)) + ) + continue + else: + raise e + else: + array = saved_weight_value + + # We create the tuple that will be loaded and add it to the final list + weight_value_tuples.append((symbolic_weight, array)) + + # Load all the weights + K.batch_set_value(weight_value_tuples) + + # Compute the missing and unexpected layers + missing_layers.extend(list(symbolic_weights_names - saved_weight_names_set)) + unexpected_layers.extend(list(saved_weight_names_set - symbolic_weights_names)) + + return missing_layers, unexpected_layers, mismatched_layers + + +def load_tf_weights_from_safetensors(model, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None): + # Read the safetensors file + with safe_open(resolved_archive_file, framework="tf") as safetensors_archive: + mismatched_layers = [] + weight_names = [strip_model_name_and_prefix(w.name, _prefix=_prefix) for w in model.weights] + loaded_weight_names = list(safetensors_archive.keys()) + # Find the missing layers from the high level list of layers + missing_layers = list(set(weight_names) - set(loaded_weight_names)) + # Find the unexpected layers from the high level list of layers + unexpected_layers = list(set(loaded_weight_names) - set(weight_names)) + + for weight in model.weights: + weight_name = strip_model_name_and_prefix(weight.name, _prefix=_prefix) + if weight_name in loaded_weight_names: + weight_value = safetensors_archive.get_tensor(weight_name) + # Check if the shape of the current weight and the one from the H5 file are different + if K.int_shape(weight) != weight_value.shape: + # If yes we reshape the weight from the H5 file accordingly to the current weight + # If the two shapes are not compatible we raise an issue + try: + weight_value = tf.reshape(weight_value, K.int_shape(weight)) + except (ValueError, tf.errors.InvalidArgumentError) as e: + if ignore_mismatched_sizes: + mismatched_layers.append((weight_name, weight_value.shape, K.int_shape(weight))) + continue + else: + raise e + + K.set_value(weight, weight_value) # weight.assign() might break if weight is a DTensor + return missing_layers, unexpected_layers, mismatched_layers + + +def init_copy_embeddings(old_embeddings, new_num_tokens): + r""" + This function aims to reduce the embeddings in case new_num_tokens < old_num_tokens or to pad with -1 in case + new_num_tokens > old_num_tokens. A mask is also computed in order to know which weight in the embeddings should be + kept or not. Example: + + - if new_num_tokens=5 and old_num_tokens=4 and old_embeddings=[w1,w2,w3,w4] + + - mask=[True,True,True,True,False] and current_weights=[w1,w2,w3,w4,-1] + - if new_num_tokens=4 and old_num_tokens=5 and old_embeddings=[w1,w2,w3,w4,w5] + + - mask=[True,True,True,True] and current_weights=[w1,w2,w3,w4] + """ + old_num_tokens, old_embedding_dim = shape_list(old_embeddings) + size_diff = new_num_tokens - old_num_tokens + + # initialize new embeddings + # Copy token embeddings from the previous ones + if tf.math.greater(size_diff, 0): + # if the new size is greater than the old one, we extend the current embeddings with a padding until getting new size + # and we create a mask to properly identify the padded values and be replaced by the values of the newly created + # embeddings + current_weights = tf.pad( + old_embeddings.value(), tf.convert_to_tensor([[0, size_diff], [0, 0]]), constant_values=-1 + ) + num_tokens_to_copy = min(old_num_tokens, new_num_tokens) + mask = tf.fill(tf.convert_to_tensor([num_tokens_to_copy, 1]), True) + mask = tf.pad(mask, tf.convert_to_tensor([[0, size_diff], [0, 0]]), constant_values=False) + else: + # if the new size if lower than the old one, we take the current embeddings until the new size + current_weights = tf.slice( + old_embeddings.value(), + tf.convert_to_tensor([0, 0]), + tf.convert_to_tensor([new_num_tokens, old_embedding_dim]), + ) + mask = tf.fill(tf.convert_to_tensor([new_num_tokens, 1]), True) + + return mask, current_weights + + +class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushToHubMixin): + r""" + Base class for all TF models. + + [`TFPreTrainedModel`] takes care of storing the configuration of the models and handles methods for loading, + downloading and saving models as well as a few methods common to all models to: + + - resize the input embeddings, + - prune heads in the self-attention heads. + + Class attributes (overridden by derived classes): + + - **config_class** ([`PretrainedConfig`]) -- A subclass of [`PretrainedConfig`] to use as configuration class + for this model architecture. + - **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived + classes of the same architecture adding modules on top of the base model. + - **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP + models, `pixel_values` for vision models and `input_values` for speech models). + """ + + config_class = None + base_model_prefix = "" + main_input_name = "input_ids" + _auto_class = None + _using_dummy_loss = None + _label_to_output_map = None + + # a list of re pattern of tensor names to ignore from the model when loading the model weights + # (and avoid unnecessary warnings). + _keys_to_ignore_on_load_missing = None + # a list of re pattern of tensor names to ignore from the weights when loading the model weights + # (and avoid unnecessary warnings). + _keys_to_ignore_on_load_unexpected = None + _requires_load_weight_prefix = False + + @property + def dummy_inputs(self) -> dict[str, tf.Tensor]: + """ + Dummy inputs to build the network. + + Returns: + `dict[str, tf.Tensor]`: The dummy inputs. + """ + dummies = {} + for key, spec in self.input_signature.items(): + # 2 is the most correct arbitrary size. I will not be taking questions + dummy_shape = [dim if dim is not None else 2 for dim in spec.shape] + if spec.shape[0] is None: + # But let's make the batch size 1 to save memory anyway + dummy_shape[0] = 1 + dummies[key] = tf.ones(shape=dummy_shape, dtype=spec.dtype) + if key == "token_type_ids": + # Some models have token_type_ids but with a vocab_size of 1 + dummies[key] = tf.zeros_like(dummies[key]) + if self.config.add_cross_attention and "encoder_hidden_states" in inspect.signature(self.call).parameters: + if "encoder_hidden_states" not in dummies: + if self.main_input_name == "input_ids": + dummies["encoder_hidden_states"] = tf.ones( + shape=(1, 2, self.config.hidden_size), dtype=tf.float32, name="encoder_hidden_states" + ) + else: + raise NotImplementedError( + "Model has cross-attention but we couldn't infer the shape for the encoder hidden states. Please manually override dummy_inputs!" + ) + return dummies + + def build_in_name_scope(self): + with tf.name_scope(self.name): + self.build(input_shape=None) + + @property + def framework(self) -> str: + """ + :str: Identifies that this is a TensorFlow model. + """ + return "tf" + + def build(self, input_shape=None): + pass # This is just here to make sure we don't call the superclass build() + + def __init__(self, config, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + if not isinstance(config, PretrainedConfig): + raise TypeError( + f"Parameter config in `{self.__class__.__name__}(config)` should be an instance of class " + "`PretrainedConfig`. To create a model from a pretrained model use " + f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`" + ) + # Save config and origin of the pretrained weights if given in model + self.config = config + self.name_or_path = config.name_or_path + self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None + self._set_save_spec(self.input_signature) + logger.warning_once( + "TensorFlow and JAX classes are deprecated and will be removed in Transformers v5. We " + "recommend migrating to PyTorch classes or pinning your version of Transformers." + ) + + def get_config(self): + return self.config.to_dict() + + @functools.wraps(keras.Model.fit) + def fit(self, *args, **kwargs): + args, kwargs = convert_batch_encoding(*args, **kwargs) + return super().fit(*args, **kwargs) + + @functools.wraps(keras.Model.train_on_batch) + def train_on_batch(self, *args, **kwargs): + args, kwargs = convert_batch_encoding(*args, **kwargs) + return super().train_on_batch(*args, **kwargs) + + @functools.wraps(keras.Model.test_on_batch) + def test_on_batch(self, *args, **kwargs): + args, kwargs = convert_batch_encoding(*args, **kwargs) + return super().test_on_batch(*args, **kwargs) + + @functools.wraps(keras.Model.predict_on_batch) + def predict_on_batch(self, *args, **kwargs): + args, kwargs = convert_batch_encoding(*args, **kwargs) + return super().predict_on_batch(*args, **kwargs) + + @functools.wraps(keras.Model.predict) + def predict(self, *args, **kwargs): + args, kwargs = convert_batch_encoding(*args, **kwargs) + return super().predict(*args, **kwargs) + + @functools.wraps(keras.Model.evaluate) + def evaluate(self, *args, **kwargs): + args, kwargs = convert_batch_encoding(*args, **kwargs) + return super().evaluate(*args, **kwargs) + + @classmethod + def from_config(cls, config, **kwargs): + if isinstance(config, PretrainedConfig): + return cls._from_config(config, **kwargs) + return cls._from_config(cls.config_class.from_dict(config, **kwargs)) + + @classmethod + def _from_config(cls, config, **kwargs): + """ + All context managers that the model should be initialized under go here. + """ + return cls(config, **kwargs) + + def get_head_mask(self, head_mask: tf.Tensor | None, num_hidden_layers: int) -> tf.Tensor: + """ + Prepare the head mask if needed. + + Args: + head_mask (`tf.Tensor` with shape `[num_heads]` or `[num_hidden_layers x num_heads]`, *optional*): + The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard). + num_hidden_layers (`int`): + The number of hidden layers in the model. + + Returns: + `tf.Tensor` with shape `[num_hidden_layers x batch x num_heads x seq_length x seq_length]` or list with + `[None]` for each layer. + """ + if head_mask is not None: + head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers) + else: + head_mask = [None] * num_hidden_layers + + return head_mask + + def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers): + """-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]""" + if head_mask.shape.rank == 1: + head_mask = head_mask[None, None, :, None, None] + head_mask = tf.repeat(head_mask, repeats=num_hidden_layers, axis=0) + elif head_mask.shape.rank == 2: + head_mask = head_mask[:, None, :, None, None] + assert head_mask.shape.rank == 5, f"head_mask.dim != 5, instead {head_mask.dim()}" + head_mask = tf.cast(head_mask, tf.float32) # switch to float if need + fp16 compatibility + return head_mask + + @tf.function + def serving(self, inputs): + """ + Args: + Method used for serving the model. Does not have a specific signature, but will be specialized as concrete + functions when saving with `save_pretrained`. + inputs (`dict[str, tf.Tensor]`): + The input of the saved model as a dictionary of tensors. + """ + output = self.call(inputs) + + return self.serving_output(output) + + @property + def input_signature(self) -> dict[str, tf.TensorSpec]: + """ + This property should return a dict mapping input names to tf.TensorSpec objects, representing the expected + shape and dtype for model inputs. It is used for both serving and for generating dummy inputs. + """ + model_inputs = list(inspect.signature(self.call).parameters) + sig = {} + if "input_ids" in model_inputs: + if self.__class__.__name__.endswith("ForMultipleChoice"): + text_dims = 3 + else: + text_dims = 2 + for input_name in ( + "input_ids", + "attention_mask", + "token_type_ids", + "decoder_input_ids", + "decoder_attention_mask", + ): + if input_name in model_inputs: + sig[input_name] = tf.TensorSpec([None] * text_dims, tf.int32, name=input_name) + if "pixel_values" in model_inputs: + pixel_values_shape = [None, None, None, None] + if hasattr(self.config, "vision_config"): + vision_config = self.config.vision_config + else: + vision_config = self.config + if hasattr(vision_config, "num_channels"): + pixel_values_shape[1] = vision_config.num_channels + else: + raise NotImplementedError( + "Could not infer number of channels from config, please override input_signature to specify input shapes." + ) + if hasattr(vision_config, "image_size"): + pixel_values_shape[2] = pixel_values_shape[3] = vision_config.image_size + elif hasattr(vision_config, "input_size"): + pixel_values_shape[2] = pixel_values_shape[3] = vision_config.input_size + else: + raise NotImplementedError( + "Could not infer input image shape from config, please override input_signature to specify input shapes." + ) + sig["pixel_values"] = tf.TensorSpec(pixel_values_shape, tf.float32, name="pixel_values") + if "input_features" in model_inputs: + raise NotImplementedError("Audio models need a manually defined input_signature") + return sig + + def serving_output(self, output): + """ + Prepare the output of the saved model. Can be overridden if specific serving modifications are required. + """ + if not isinstance(output, ModelOutput): + return output + for key in output: + if key.endswith("hidden_states") and not getattr(self.config, "output_hidden_states", False): + output[key] = None + elif key.endswith("attentions") and not getattr(self.config, "output_attentions", False): + output[key] = None + elif key == "past_key_values" and not getattr(self.config, "use_cache", False): + output[key] = None + elif key == "cross_attentions" and not ( + getattr(self.config, "output_attentions", False) and getattr(self.config, "add_cross_attention", False) + ): + output[key] = None + if isinstance(output[key], (tuple, list)): + try: + output[key] = tf.convert_to_tensor(output[key]) + except (ValueError, tf.errors.InvalidArgumentError): + pass # Layers may not have the same dimensions + return output + + @classmethod + def can_generate(cls) -> bool: + """ + Returns whether this model can generate sequences with `.generate()`. + + Returns: + `bool`: Whether this model can generate sequences with `.generate()`. + """ + # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation. + # Alternatively, the model can also have a custom `generate` function. + if "GenerationMixin" in str(cls.prepare_inputs_for_generation) and "GenerationMixin" in str(cls.generate): + return False + return True + + def get_input_embeddings(self) -> keras.layers.Layer: + """ + Returns the model's input embeddings layer. + + Returns: + `tf.Variable`: The embeddings layer mapping vocabulary to hidden states. + """ + main_layer = getattr(self, self.base_model_prefix, self) + + if main_layer is not self: + return main_layer.get_input_embeddings() + else: + raise NotImplementedError + + def _save_checkpoint(self, checkpoint_dir, epoch): + if not os.path.isdir(checkpoint_dir): + os.mkdir(checkpoint_dir) + # We avoid tf.train.checkpoint or saving weights in TF format, even though that includes optimizer + # state for us, because it requires special handling for objects like custom losses, which we use + # internally and which users are likely to use too + weights_path = os.path.join(checkpoint_dir, "weights.h5") + self.save_weights(weights_path) + extra_data = {"epoch": epoch, "optimizer_state": self.optimizer.get_weights()} + extra_data_path = os.path.join(checkpoint_dir, "extra_data.pickle") + with open(extra_data_path, "wb") as f: + pickle.dump(extra_data, f) + + def prepare_tf_dataset( + self, + dataset: datasets.Dataset, # noqa:F821 + batch_size: int = 8, + shuffle: bool = True, + tokenizer: PreTrainedTokenizerBase | None = None, + collate_fn: Callable | None = None, + collate_fn_args: dict[str, Any] | None = None, + drop_remainder: bool | None = None, + prefetch: bool = True, + ): + """ + Wraps a HuggingFace [`~datasets.Dataset`] as a `tf.data.Dataset` with collation and batching. This method is + designed to create a "ready-to-use" dataset that can be passed directly to Keras methods like `fit()` without + further modification. The method will drop columns from the dataset if they don't match input names for the + model. If you want to specify the column names to return rather than using the names that match this model, we + recommend using `Dataset.to_tf_dataset()` instead. + + Args: + dataset (`Any`): + A [~`datasets.Dataset`] to be wrapped as a `tf.data.Dataset`. + batch_size (`int`, *optional*, defaults to 8): + The size of batches to return. + shuffle (`bool`, defaults to `True`): + Whether to return samples from the dataset in random order. Usually `True` for training datasets and + `False` for validation/test datasets. + tokenizer ([`PreTrainedTokenizerBase`], *optional*): + A `PreTrainedTokenizer` that will be used to pad samples to create batches. Has no effect if a specific + `collate_fn` is passed instead. + collate_fn (`Callable`, *optional*): + A function that collates samples from the dataset into a single batch. Defaults to + `DefaultDataCollator` if no `tokenizer` is supplied or `DataCollatorWithPadding` if a `tokenizer` is + passed. + collate_fn_args (`dict[str, Any]`, *optional*): + A dict of arguments to pass to the `collate_fn` alongside the list of samples. + drop_remainder (`bool`, *optional*): + Whether to drop the final batch, if the batch_size does not evenly divide the dataset length. Defaults + to the same setting as `shuffle`. + prefetch (`bool`, defaults to `True`): + Whether to add prefetching to the end of the `tf.data` pipeline. This is almost always beneficial for + performance, but can be disabled in edge cases. + + + Returns: + `Dataset`: A `tf.data.Dataset` which is ready to pass to the Keras API. + """ + requires_backends(self, ["datasets"]) + import datasets + + if collate_fn is None: + if tokenizer is None: + collate_fn = DefaultDataCollator(return_tensors="np") + else: + collate_fn = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="np") + if collate_fn_args is None: + collate_fn_args = {} + + if not isinstance(dataset, datasets.Dataset): + raise TypeError("Dataset argument should be a datasets.Dataset!") + model_inputs = list(inspect.signature(self.call).parameters) + model_labels = find_labels(self.__class__) + if "cols_to_retain" in list(inspect.signature(dataset._get_output_signature).parameters.keys()): + output_signature, _ = dataset._get_output_signature( + dataset, + batch_size=None, + collate_fn=collate_fn, + collate_fn_args=collate_fn_args, + cols_to_retain=model_inputs, + ) + else: + # TODO Matt: This is a workaround for older versions of datasets that are missing the `cols_to_retain` + # argument. We should remove this once the minimum supported version of datasets is > 2.3.2 + unwanted_columns = [ + feature + for feature in dataset.features + if feature not in model_inputs and feature not in ("label_ids", "label") + ] + dataset = dataset.remove_columns(unwanted_columns) + output_signature, _ = dataset._get_output_signature( + dataset, batch_size=None, collate_fn=collate_fn, collate_fn_args=collate_fn_args + ) + output_columns = list(output_signature.keys()) + feature_cols = [col for col in output_columns if col in model_inputs and col not in model_labels] + label_cols = [col for col in output_columns if col in model_labels] + + # Backwards compatibility for older versions of datasets. Previously, if `columns` or `label_cols` + # were a single element list, the returned element spec would be a single element. Now, passing [feature] + # will return a dict structure {"feature": feature}, and passing a single string will return a single element. + feature_cols = feature_cols[0] if len(feature_cols) == 1 else feature_cols + label_cols = label_cols[0] if len(label_cols) == 1 else label_cols + + if drop_remainder is None: + drop_remainder = shuffle + tf_dataset = dataset.to_tf_dataset( + columns=feature_cols, + label_cols=label_cols, + batch_size=batch_size, + shuffle=shuffle, + drop_remainder=drop_remainder, + collate_fn=collate_fn, + collate_fn_args=collate_fn_args, + prefetch=prefetch, + ) + return tf_dataset + + def compile( + self, + optimizer="rmsprop", + loss="auto_with_warning", + metrics=None, + loss_weights=None, + weighted_metrics=None, + run_eagerly=None, + steps_per_execution=None, + **kwargs, + ): + """ + This is a thin wrapper that sets the model's loss output head as the loss if the user does not specify a loss + function themselves. + """ + if loss in ("auto_with_warning", "passthrough"): # "passthrough" for workflow backward compatibility + logger.info( + "No loss specified in compile() - the model's internal loss computation will be used as the " + "loss. Don't panic - this is a common way to train TensorFlow models in Transformers! " + "To disable this behaviour please pass a loss argument, or explicitly pass " + "`loss=None` if you do not want your model to compute a loss. You can also specify `loss='auto'` to " + "get the internal loss without printing this info string." + ) + loss = "auto" + if loss == "auto": + loss = dummy_loss + self._using_dummy_loss = True + else: + self._using_dummy_loss = False + parent_args = list(inspect.signature(keras.Model.compile).parameters.keys()) + # This argument got renamed, we need to support both versions + if "steps_per_execution" in parent_args: + super().compile( + optimizer=optimizer, + loss=loss, + metrics=metrics, + loss_weights=loss_weights, + weighted_metrics=weighted_metrics, + run_eagerly=run_eagerly, + steps_per_execution=steps_per_execution, + **kwargs, + ) + else: + super().compile( + optimizer=optimizer, + loss=loss, + metrics=metrics, + loss_weights=loss_weights, + weighted_metrics=weighted_metrics, + run_eagerly=run_eagerly, + experimental_steps_per_execution=steps_per_execution, + **kwargs, + ) + + def compute_loss(self, *args, **kwargs): + if hasattr(keras.Model, "compute_loss"): + # This will be true in TF 2.8 or greater + return super().compute_loss(*args, **kwargs) + else: + warnings.warn( + "The old compute_loss method is deprecated as it conflicts with the Keras compute_loss " + "method added in TF 2.8. If you want the original HF compute_loss, please call " + "hf_compute_loss() instead. From TF versions >= 2.8, or Transformers versions >= 5, " + "calling compute_loss() will get the Keras method instead.", + FutureWarning, + ) + return self.hf_compute_loss(*args, **kwargs) + + def get_label_to_output_name_mapping(self): + arg_names = list(inspect.signature(self.call).parameters) + if self._label_to_output_map is not None: + return self._label_to_output_map + elif "start_positions" in arg_names: + return {"start_positions": "start_logits", "end_positions": "end_logits"} + elif "sentence_order_label" in arg_names: + return {"labels": "prediction_logits", "sentence_order_label": "sop_logits"} + elif "next_sentence_label" in arg_names: + return {"labels": "prediction_logits", "next_sentence_label": "seq_relationship_logits"} + elif "mc_labels" in arg_names: + return {"labels": "logits", "mc_labels": "mc_logits"} + else: + return {} + + def train_step(self, data): + """ + A modification of Keras's default `train_step` that correctly handles matching outputs to labels for our models + and supports directly training on the loss output head. In addition, it ensures input keys are copied to the + labels where appropriate. It will also copy label keys into the input dict when using the dummy loss, to ensure + that they are available to the model during the forward pass. + """ + + # We hardcode the most common renamings; models with weirder names can set `self._label_to_output_map` + arg_names = list(inspect.signature(self.call).parameters) + label_kwargs = find_labels(self.__class__) + label_to_output = self.get_label_to_output_name_mapping() + output_to_label = {val: key for key, val in label_to_output.items()} + if not self._using_dummy_loss and parse(tf.__version__) < parse("2.11.0"): + # Newer TF train steps leave this out + data = expand_1d(data) + x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data) + # If the inputs are mutable dictionaries, make a shallow copy of them because we will modify + # them during input/label pre-processing. This avoids surprising the user by wrecking their data. + # In addition, modifying mutable Python inputs makes XLA compilation impossible. + if isinstance(x, dict): + x = x.copy() + if isinstance(y, dict): + y = y.copy() + + # When using a dummy loss, we ensure that separate labels are copied to the correct model arguments, + # if those keys are not already present in the input dict + if self._using_dummy_loss and y is not None: + # If y is a tensor and the model only has one label-like input, map y to that input + if len(label_kwargs) == 1 and isinstance(y, tf.Tensor): + if isinstance(x, tf.Tensor): + x = {arg_names[0]: x} + label_kwarg = next(iter(label_kwargs)) + if label_kwarg not in x: + x[label_kwarg] = y + # Otherwise, copy keys from y to x as long as they weren't already present in x + elif isinstance(y, dict): + if isinstance(x, tf.Tensor): + x = {arg_names[0]: x} + for key, val in y.items(): + if key in arg_names and key not in x: + x[key] = val + elif output_to_label.get(key) in arg_names and key not in x: + x[output_to_label[key]] = val + if y is None: + y = {key: val for key, val in x.items() if key in label_kwargs} + if not y and not self._using_dummy_loss: + raise ValueError("Could not find label column(s) in input dict and no separate labels were provided!") + + if isinstance(y, dict): + # Rename labels at this point to match output heads + y = {label_to_output.get(key, key): val for key, val in y.items()} + + # Run forward pass. + with tf.GradientTape() as tape: + if self._using_dummy_loss and "return_loss" in arg_names: + y_pred = self(x, training=True, return_loss=True) + else: + y_pred = self(x, training=True) + if self._using_dummy_loss: + loss = self.compiled_loss(y_pred.loss, y_pred.loss, sample_weight, regularization_losses=self.losses) + else: + loss = None + + # This next block matches outputs to label keys. Tensorflow's standard method for doing this + # can get very confused if any of the keys contain nested values (e.g. lists/tuples of Tensors) + if isinstance(y, dict) and len(y) == 1: + if list(y.keys())[0] in y_pred: + y_pred = y_pred[list(y.keys())[0]] + elif list(y_pred.keys())[0] == "loss": + y_pred = y_pred[1] + else: + y_pred = y_pred[0] + _, y = y.popitem() + elif isinstance(y, dict): + # If the labels are a dict, match keys from the output by name + y_pred = {key: val for key, val in y_pred.items() if key in y} + elif isinstance(y, (tuple, list)): + # If the labels are a tuple/list, match keys to the output by order, skipping the loss. + if list(y_pred.keys())[0] == "loss": + y_pred = y_pred.to_tuple()[1:] + else: + y_pred = y_pred.to_tuple() + y_pred = y_pred[: len(y)] # Remove unused fields in case those cause problems + else: + # If the labels are a single tensor, match them to the first non-loss tensor in the output + if list(y_pred.keys())[0] == "loss": + y_pred = y_pred[1] + else: + y_pred = y_pred[0] + + if loss is None: + loss = self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses) + + # Run backwards pass. + self.optimizer.minimize(loss, self.trainable_variables, tape=tape) + + self.compiled_metrics.update_state(y, y_pred, sample_weight) + # Collect metrics to return + return_metrics = {} + for metric in self.metrics: + result = metric.result() + if isinstance(result, dict): + return_metrics.update(result) + else: + return_metrics[metric.name] = result + return return_metrics + + def test_step(self, data): + """ + A modification of Keras's default `train_step` that correctly handles matching outputs to labels for our models + and supports directly training on the loss output head. In addition, it ensures input keys are copied to the + labels where appropriate. It will also copy label keys into the input dict when using the dummy loss, to ensure + that they are available to the model during the forward pass. + """ + # We hardcode the most common renamings; models with weirder names can set `self._label_to_output_map` + arg_names = list(inspect.signature(self.call).parameters) + label_kwargs = find_labels(self.__class__) + label_to_output = self.get_label_to_output_name_mapping() + output_to_label = {val: key for key, val in label_to_output.items()} + if not self._using_dummy_loss and parse(tf.__version__) < parse("2.11.0"): + # Newer versions leave this out + data = expand_1d(data) + x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data) + # If the inputs are mutable dictionaries, make a shallow copy of them because we will modify + # them during input/label pre-processing. This avoids surprising the user by wrecking their data. + # In addition, modifying mutable Python inputs makes XLA compilation impossible. + if isinstance(x, dict): + x = x.copy() + if isinstance(y, dict): + y = y.copy() + + # When using a dummy loss, we ensure that separate labels are copied to the correct model arguments, + # if those keys are not already present in the input dict + if self._using_dummy_loss and y is not None: + arg_names = list(inspect.signature(self.call).parameters) + # If y is a tensor and the model only has one label-like input, map y to that input + if len(label_kwargs) == 1 and isinstance(y, tf.Tensor): + if isinstance(x, tf.Tensor): + x = {arg_names[0]: x} + label_kwarg = next(iter(label_kwargs)) + if label_kwarg not in x: + x[label_kwarg] = y + # Otherwise, copy keys from y to x as long as they weren't already present in x + elif isinstance(y, dict): + if isinstance(x, tf.Tensor): + x = {arg_names[0]: x} + for key, val in y.items(): + if key in arg_names and key not in x: + x[key] = val + elif output_to_label.get(key) in arg_names and key not in x: + x[output_to_label[key]] = val + if y is None: + y = {key: val for key, val in x.items() if key in label_kwargs} + if not y and not self._using_dummy_loss: + raise ValueError("Could not find label column(s) in input dict and no separate labels were provided!") + + if isinstance(y, dict): + # Rename labels at this point to match output heads + y = {label_to_output.get(key, key): val for key, val in y.items()} + + # Run forward pass. + if self._using_dummy_loss and "return_loss" in arg_names: + y_pred = self(x, return_loss=True, training=False) + else: + y_pred = self(x, training=False) + if self._using_dummy_loss: + loss = self.compiled_loss(y_pred.loss, y_pred.loss, sample_weight, regularization_losses=self.losses) + else: + loss = None + + # This next block matches outputs to label keys. Tensorflow's standard method for doing this + # can get very confused if any of the keys contain nested values (e.g. lists/tuples of Tensors) + if isinstance(y, dict) and len(y) == 1: + if list(y.keys())[0] in y_pred: + y_pred = y_pred[list(y.keys())[0]] + elif list(y_pred.keys())[0] == "loss": + y_pred = y_pred[1] + else: + y_pred = y_pred[0] + _, y = y.popitem() + elif isinstance(y, dict): + # If the labels are a dict, match keys from the output by name + y_pred = {key: val for key, val in y_pred.items() if key in y} + elif isinstance(y, (tuple, list)): + # If the labels are a tuple/list, match keys to the output by order, skipping the loss. + if list(y_pred.keys())[0] == "loss": + y_pred = y_pred.to_tuple()[1:] + else: + y_pred = y_pred.to_tuple() + y_pred = y_pred[: len(y)] # Remove unused fields in case those cause problems + else: + # If the labels are a single tensor, match them to the first non-loss tensor in the output + if list(y_pred.keys())[0] == "loss": + y_pred = y_pred[1] + else: + y_pred = y_pred[0] + + if loss is None: + loss = self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses) + + self.compiled_metrics.update_state(y, y_pred, sample_weight) + # Collect metrics to return + return_metrics = {} + for metric in self.metrics: + result = metric.result() + if isinstance(result, dict): + return_metrics.update(result) + else: + return_metrics[metric.name] = result + return return_metrics + + def create_model_card( + self, + output_dir, + model_name: str, + language: str | None = None, + license: str | None = None, + tags: str | None = None, + finetuned_from: str | None = None, + tasks: str | None = None, + dataset_tags: str | list[str] | None = None, + dataset: str | list[str] | None = None, + dataset_args: str | list[str] | None = None, + ): + """ + Creates a draft of a model card using the information available to the `Trainer`. + + Args: + output_dir (`str` or `os.PathLike`): + The folder in which to create the model card. + model_name (`str`, *optional*): + The name of the model. + language (`str`, *optional*): + The language of the model (if applicable) + license (`str`, *optional*): + The license of the model. Will default to the license of the pretrained model used, if the original + model given to the `Trainer` comes from a repo on the Hub. + tags (`str` or `list[str]`, *optional*): + Some tags to be included in the metadata of the model card. + finetuned_from (`str`, *optional*): + The name of the model used to fine-tune this one (if applicable). Will default to the name of the repo + of the original model given to the `Trainer` (if it comes from the Hub). + tasks (`str` or `list[str]`, *optional*): + One or several task identifiers, to be included in the metadata of the model card. + dataset_tags (`str` or `list[str]`, *optional*): + One or several dataset tags, to be included in the metadata of the model card. + dataset (`str` or `list[str]`, *optional*): + One or several dataset identifiers, to be included in the metadata of the model card. + dataset_args (`str` or `list[str]`, *optional*): + One or several dataset arguments, to be included in the metadata of the model card. + """ + # Avoids a circular import by doing this when necessary. + from .modelcard import TrainingSummary # tests_ignore + + training_summary = TrainingSummary.from_keras( + self, + keras_history=self.history, + language=language, + license=license, + tags=tags, + model_name=model_name, + finetuned_from=finetuned_from, + tasks=tasks, + dataset_tags=dataset_tags, + dataset=dataset, + dataset_args=dataset_args, + ) + model_card = training_summary.to_model_card() + with open(os.path.join(output_dir, "README.md"), "w") as f: + f.write(model_card) + + def set_input_embeddings(self, value): + """ + Set model's input embeddings + + Args: + value (`tf.Variable`): + The new weights mapping hidden states to vocabulary. + """ + main_layer = getattr(self, self.base_model_prefix) + + if main_layer is None: + raise NotImplementedError("The model does not implements the base_model_prefix attribute.") + + try: + main_layer.set_input_embeddings(value) + except AttributeError: + logger.info("Building the model") + self.build_in_name_scope() + main_layer.set_input_embeddings(value) + + def get_output_embeddings(self) -> None | keras.layers.Layer: + """ + Returns the model's output embeddings + + Returns: + `tf.Variable`: The new weights mapping vocabulary to hidden states. + """ + if self.get_lm_head() is not None: + lm_head = self.get_lm_head() + + try: + return lm_head.get_output_embeddings() + except AttributeError: + logger.info("Building the model") + self.build_in_name_scope() + + return lm_head().get_output_embeddings() + + return None # Overwrite for models with output embeddings + + def set_output_embeddings(self, value): + """ + Set model's output embeddings + + Args: + value (`tf.Variable`): + The new weights mapping hidden states to vocabulary. + """ + if self.get_lm_head() is not None: + lm_head = self.get_lm_head() + try: + lm_head.set_output_embeddings(value) + except AttributeError: + logger.info("Building the model") + self.build_in_name_scope() + lm_head.set_output_embeddings(value) + + def get_output_layer_with_bias(self) -> None | keras.layers.Layer: + """ + Get the layer that handles a bias attribute in case the model has an LM head with weights tied to the + embeddings + + Return: + `keras.layers.Layer`: The layer that handles the bias, None if not an LM model. + """ + warnings.warn( + "The method get_output_layer_with_bias is deprecated. Please use `get_lm_head` instead.", FutureWarning + ) + return self.get_lm_head() + + def get_prefix_bias_name(self) -> None | str: + """ + Get the concatenated _prefix name of the bias from the model name to the parent layer + + Return: + `str`: The _prefix name of the bias. + """ + warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) + return None + + def get_bias(self) -> None | dict[str, tf.Variable]: + """ + Dict of bias attached to an LM head. The key represents the name of the bias attribute. + + Return: + `tf.Variable`: The weights representing the bias, None if not an LM model. + """ + if self.get_lm_head() is not None: + lm_head = self.get_lm_head() + try: + return lm_head.get_bias() + except AttributeError: + self.build_in_name_scope() + + return lm_head.get_bias() + return None + + def set_bias(self, value): + """ + Set all the bias in the LM head. + + Args: + value (`dict[tf.Variable]`): + All the new bias attached to an LM head. + """ + if self.get_lm_head() is not None: + lm_head = self.get_lm_head() + try: + lm_head.set_bias(value) + except AttributeError: + self.build_in_name_scope() + lm_head.set_bias(value) + + def get_lm_head(self) -> keras.layers.Layer: + """ + The LM Head layer. This method must be overwritten by all the models that have a lm head. + + Return: + `keras.layers.Layer`: The LM head layer if the model has one, None if not. + """ + return None + + def resize_token_embeddings(self, new_num_tokens: int | None = None) -> keras.layers.Embedding | tf.Variable: + """ + Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`. + + Takes care of tying weights embeddings afterwards if the model class has a `tie_weights()` method. + + Arguments: + new_num_tokens (`int`, *optional*): + The number of new tokens in the embedding matrix. Increasing the size will add newly initialized + vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just + returns a pointer to the input tokens without doing anything. + + Return: + `tf.Variable` or `keras.layers.Embedding`: Pointer to the input tokens of the model. + """ + # TODO (joao): flagged for replacement (by `_v2_resized_token_embeddings`) due to embeddings refactor + + # Run the new code path if the model has a keras embeddings layer + if isinstance(self.get_input_embeddings(), keras.layers.Embedding): + return self._v2_resized_token_embeddings(new_num_tokens) + + if new_num_tokens is None or new_num_tokens == self.config.vocab_size: + return self._get_word_embedding_weight(self.get_input_embeddings()) + + model_embeds = self._resize_token_embeddings(new_num_tokens) + + # Update base model and current model config + self.config.vocab_size = new_num_tokens + + return model_embeds + + def _v2_resized_token_embeddings(self, new_num_tokens: int | None = None) -> keras.layers.Embedding: + """ + Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`. + + Arguments: + new_num_tokens (`int`, *optional*): + The number of new tokens in the embedding matrix. Increasing the size will add newly initialized + vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just + returns a pointer to the input tokens without doing anything. + + Return: + `keras.layers.Embedding`: Pointer to the input tokens of the model. + """ + if new_num_tokens is None or new_num_tokens == self.config.vocab_size: + return self.get_input_embeddings() + + model_embeds = self._v2_resize_token_embeddings(new_num_tokens) + + # Update base model and current model config + self.config.vocab_size = new_num_tokens + + return model_embeds + + def _get_word_embedding_weight(model, embedding_layer): + # TODO (joao): flagged for detection due to embeddings refactor + + # If the variable holds the weights themselves, return them + if isinstance(embedding_layer, tf.Tensor): + return embedding_layer + # Otherwise, try to get them from the layer's attributes + + embeds = getattr(embedding_layer, "weight", None) + if embeds is not None: + return embeds + + embeds = getattr(embedding_layer, "decoder", None) + if embeds is not None: + return embeds + + # The reason why the attributes don't exist might be + # because the model is not built, so retry getting + # the argument after building the model + model.build_in_name_scope() + + embeds = getattr(embedding_layer, "weight", None) + if embeds is not None: + return embeds + + embeds = getattr(embedding_layer, "decoder", None) + if embeds is not None: + return embeds + + return None + + def _resize_token_embeddings(self, new_num_tokens): + # TODO (joao): flagged for replacement (by `_v2_resize_token_embeddings`) due to embeddings refactor + old_embeddings = self._get_word_embedding_weight(self.get_input_embeddings()) + new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens) + + # if word embeddings are not tied, make sure that lm head bias is resized as well + if self.get_bias() is not None: + old_lm_head_bias = self.get_bias() + new_lm_head_bias = self._get_resized_lm_head_bias(old_lm_head_bias, new_num_tokens) + + self.set_bias(new_lm_head_bias) + + # if word embeddings are not tied, make sure that lm head decoder is resized as well + if self.get_output_embeddings() is not None: + old_lm_head_decoder = self._get_word_embedding_weight(self.get_output_embeddings()) + new_lm_head_decoder = self._get_resized_lm_head_decoder(old_lm_head_decoder, new_num_tokens) + + self.set_output_embeddings(new_lm_head_decoder) + + self.set_input_embeddings(new_embeddings) + + return self.get_input_embeddings() + + def _v2_resize_token_embeddings(self, new_num_tokens): + old_embeddings = self.get_input_embeddings() + new_embeddings = self._v2_get_resized_embeddings(old_embeddings, new_num_tokens) + self.set_input_embeddings(new_embeddings) + + # If word embeddings are not tied, make sure that lm head bias is resized as well + if self.get_bias() is not None: + old_lm_head_bias = self.get_bias() + new_lm_head_bias = self._v2_get_resized_lm_head_bias(old_lm_head_bias, new_num_tokens) + self.set_bias(new_lm_head_bias) + + # If word embeddings are not tied, make sure that lm head decoder is resized as well. + tied_weights = self.get_input_embeddings() == self.get_output_embeddings() + if self.get_output_embeddings() is not None and not tied_weights: + old_lm_head_decoder = self._get_word_embedding_weight(self.get_output_embeddings()) + # TODO (joao): this one probably needs a v2 version with other models + new_lm_head_decoder = self._get_resized_lm_head_decoder(old_lm_head_decoder, new_num_tokens) + self.set_output_embeddings(new_lm_head_decoder) + + return self.get_input_embeddings() + + def _get_resized_lm_head_bias(self, old_lm_head_bias, new_num_tokens): + """ + Build a resized bias from the old ones. Increasing the size will add newly initialized vectors at the end. + Reducing the size will remove vectors from the end + + Args: + old_lm_head_bias (`tf.Variable`): + Old lm head bias to be resized. + new_num_tokens (`int`, *optional*): + New number of tokens in the linear matrix. + + Increasing the size will add newly initialized vectors at the end. Reducing the size will remove + vectors from the end. If not provided or `None`, just returns None + + Return: + `tf.Variable`: Pointer to the resized bias. + """ + # TODO (joao): flagged for replacement (by `_v2_get_resized_lm_head_bias`) due to embeddings refactor + new_lm_head_bias = {} + + for attr, weight in old_lm_head_bias.items(): + first_dim, old_num_tokens = (None, shape_list(weight)[0]) if tf.rank(weight) == 1 else shape_list(weight) + size_diff = new_num_tokens - old_num_tokens + final_shape = [new_num_tokens] if first_dim is None else [first_dim, new_num_tokens] + + # initialize new bias + if tf.math.greater(size_diff, 0): + padding_shape = [[0, size_diff]] if first_dim is None else [[0, 0], [0, size_diff]] + current_bias = tf.pad(weight.value(), tf.convert_to_tensor(padding_shape), constant_values=-1) + num_tokens_to_copy = min(old_num_tokens, new_num_tokens) + mask_shape = [num_tokens_to_copy] if first_dim is None else [1, num_tokens_to_copy] + bias_mask = tf.fill(tf.convert_to_tensor(mask_shape), True) + bias_mask = tf.pad(bias_mask, tf.convert_to_tensor(padding_shape), constant_values=False) + else: + slice_from = [0] if first_dim is None else [0, 0] + current_bias = tf.slice( + weight.value(), tf.convert_to_tensor(slice_from), tf.convert_to_tensor(final_shape) + ) + bias_mask = tf.fill(tf.convert_to_tensor(final_shape), True) + + new_bias = self.add_weight( + shape=final_shape, + initializer="zeros", + trainable=True, + name=weight.name.split(":")[0], + ) + init_bias = tf.where(bias_mask, current_bias, new_bias.value()) + + new_bias.assign(init_bias) + new_lm_head_bias[attr] = new_bias + + return new_lm_head_bias + + def _v2_get_resized_lm_head_bias( + self, old_lm_head_bias: dict[str, tf.Variable], new_num_tokens: int + ) -> dict[str, tf.Tensor]: + """ + Build a resized bias from the old ones. Increasing the size will add newly initialized vectors at the end. + Reducing the size will remove vectors from the end + + Args: + old_lm_head_bias (`dict[str, tf.Variable]`): + Old lm head bias to be resized. + new_num_tokens (`int`): + New number of tokens in the linear matrix. Increasing the size will add newly initialized vectors at + the end. Reducing the size will remove vectors from the end. + + Return: + `tf.Tensor`: Values for the resized bias. + """ + new_lm_head_bias = {} + + for attr, weight in old_lm_head_bias.items(): + # Determine the size difference (depending on the shape) + first_dim, old_num_tokens = (None, shape_list(weight)[0]) if tf.rank(weight) == 1 else shape_list(weight) + size_diff = new_num_tokens - old_num_tokens + + # Copy the old bias values to the new bias + if old_num_tokens > new_num_tokens: + new_bias = weight.value()[..., :new_num_tokens] + else: + padding_shape = [[0, size_diff]] if first_dim is None else [[0, 0], [0, size_diff]] + new_bias = tf.pad(weight.value(), tf.convert_to_tensor(padding_shape)) + + new_lm_head_bias[attr] = new_bias + return new_lm_head_bias + + def _get_resized_lm_head_decoder(self, old_lm_head_decoder, new_num_tokens): + """ + Build a resized decoder from the old ones. Increasing the size will add newly initialized vectors at the end. + Reducing the size will remove vectors from the end + + Args: + old_lm_head_decoder (`tf.Variable`): + Old lm head decoder to be resized. + new_num_tokens (`int`, *optional*): + New number of tokens in the linear matrix. + + Increasing the size will add newly initialized vectors at the end. Reducing the size will remove + vectors from the end. If not provided or `None`, just returns None + + Return: + `tf.Variable`: Pointer to the resized decoder or None if the output embeddings are different from the input + ones. + """ + new_lm_head_decoder = old_lm_head_decoder + is_input_output_equals = tf.reduce_any( + self._get_word_embedding_weight(self.get_input_embeddings()) == old_lm_head_decoder + ) + + if old_lm_head_decoder is not None and not is_input_output_equals: + old_embedding_dim = shape_list(old_lm_head_decoder)[1] + decoder_mask, current_decoder = init_copy_embeddings(old_lm_head_decoder, new_num_tokens) + new_lm_head_decoder = self.add_weight( + shape=(new_num_tokens, old_embedding_dim), + initializer="zeros", + trainable=True, + name=old_lm_head_decoder.name.split(":")[0], + ) + init_decoder = tf.where(decoder_mask, current_decoder, new_lm_head_decoder.value()) + + new_lm_head_decoder.assign(init_decoder) + + return new_lm_head_decoder + + def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None) -> tf.Variable: + """ + Build a resized Embedding weights from a provided token Embedding weights. Increasing the size will add newly + initialized vectors at the end. Reducing the size will remove vectors from the end + + Args: + old_embeddings (`tf.Variable`): + Old embeddings to be resized. + new_num_tokens (`int`, *optional*): + New number of tokens in the embedding matrix. + + Increasing the size will add newly initialized vectors at the end. Reducing the size will remove + vectors from the end. If not provided or `None`, just returns a pointer to the input tokens + `tf.Variable` module of the model without doing anything. + + Return: + `tf.Variable`: Pointer to the resized Embedding Module or the old Embedding Module if `new_num_tokens` is + `None` + """ + # TODO (joao): flagged for replacement (by `_v2_get_resized_embeddings`) due to embeddings refactor + old_embedding_dim = shape_list(old_embeddings)[1] + init_range = getattr(self.config, "initializer_range", 0.02) + embeddings_mask, current_embeddings = init_copy_embeddings(old_embeddings, new_num_tokens) + new_embeddings = self.add_weight( + name=old_embeddings.name.split(":")[0], + shape=[new_num_tokens, old_embedding_dim], + initializer=get_initializer(init_range), + dtype=tf.float32, + ) + init_embeddings = tf.where(embeddings_mask, current_embeddings, new_embeddings.value()) + + new_embeddings.assign(init_embeddings) + + return new_embeddings + + def _v2_get_resized_embeddings( + self, old_embeddings: keras.layers.Embedding, new_num_tokens: int + ) -> keras.layers.Embedding: + """ + Build a resized Embedding layer from a provided Embedding layer. Increasing the size will add newly initialized + vectors at the end. Reducing the size will remove vectors from the end. + + Args: + old_embeddings (`keras.layers.Embedding`): + Old embeddings to be resized. + new_num_tokens (`int`, *optional*): + New number of tokens in the embedding matrix. + + Return: + `keras.layers.Embedding`: Resized Embedding layer. + """ + + # Get the initialization range for the embeddings + init_range = 0.02 # default value + potential_initialization_variable_names = [ + "initializer_range", # most common + "initializer_factor", # e.g. T5 + "init_std", # e.g BART + ] + for var_name in potential_initialization_variable_names: + if hasattr(self.config, var_name): + init_range = getattr(self.config, var_name) + + # Get a new (initialized) embeddings layer + new_embeddings = keras.layers.Embedding( + input_dim=new_num_tokens, + output_dim=old_embeddings.output_dim, + embeddings_initializer=keras.initializers.TruncatedNormal(stddev=init_range), + name=old_embeddings.embeddings.name[:-13], # exact same scoped name except "/embeddings:0" + ) + new_embeddings(tf.constant([[0]])) + + # Copy the old embeddings to the new embeddings + if old_embeddings.input_dim >= new_num_tokens: + init_embeddings = old_embeddings.embeddings[:new_num_tokens] + else: + init_embeddings = tf.concat( + [old_embeddings.embeddings, new_embeddings.embeddings[old_embeddings.input_dim :]], axis=0 + ) + new_embeddings.embeddings.assign(init_embeddings) + return new_embeddings + + def prune_heads(self, heads_to_prune): + """ + Prunes heads of the base model. + + Arguments: + heads_to_prune (`dict[int, list[int]]`): + Dictionary with keys being selected layer indices (`int`) and associated values being the list of heads + to prune in said layer (list of `int`). For instance {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on + layer 1 and heads 2 and 3 on layer 2. + """ + raise NotImplementedError + + def save_pretrained( + self, + save_directory, + saved_model=False, + version=1, + push_to_hub=False, + signatures=None, + max_shard_size: int | str = "5GB", + create_pr: bool = False, + safe_serialization: bool = False, + token: str | bool | None = None, + **kwargs, + ): + """ + Save a model and its configuration file to a directory, so that it can be re-loaded using the + [`~TFPreTrainedModel.from_pretrained`] class method. + + Arguments: + save_directory (`str`): + Directory to which to save. Will be created if it doesn't exist. + saved_model (`bool`, *optional*, defaults to `False`): + If the model has to be saved in saved model format as well or not. + version (`int`, *optional*, defaults to 1): + The version of the saved model. A saved model needs to be versioned in order to be properly loaded by + TensorFlow Serving as detailed in the official documentation + https://www.tensorflow.org/tfx/serving/serving_basic + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the + repository you want to push to with `repo_id` (will default to the name of `save_directory` in your + namespace). + signatures (`dict` or `tf.function`, *optional*): + Model's signature used for serving. This will be passed to the `signatures` argument of model.save(). + max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`): + The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size + lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`). + + + + If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard + which will be bigger than `max_shard_size`. + + + + create_pr (`bool`, *optional*, defaults to `False`): + Whether or not to create a PR with the uploaded files or directly commit. + safe_serialization (`bool`, *optional*, defaults to `False`): + Whether to save the model using `safetensors` or the traditional TensorFlow way (that uses `h5`). + token (`str` or `bool`, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use + the token generated when running `hf auth login` (stored in `~/.huggingface`). + kwargs (`dict[str, Any]`, *optional*): + Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. + """ + use_auth_token = kwargs.pop("use_auth_token", None) + + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if token is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + token = use_auth_token + + if token is not None: + kwargs["token"] = token + + if os.path.isfile(save_directory): + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") + return + + os.makedirs(save_directory, exist_ok=True) + + if push_to_hub: + commit_message = kwargs.pop("commit_message", None) + repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) + repo_id = self._create_repo(repo_id, **kwargs) + files_timestamps = self._get_files_timestamps(save_directory) + + if saved_model: + # If `torch_dtype` is in the config with a torch dtype class as the value, we need to change it to string. + # (Although TF doesn't care about this attribute, we can't just remove it or set it to `None`.) + if getattr(self.config, "torch_dtype", None) is not None and not isinstance(self.config.torch_dtype, str): + self.config.torch_dtype = str(self.config.torch_dtype).split(".")[1] + if signatures is None: + serving_default = self.serving.get_concrete_function(self.input_signature) + if any(spec.dtype == tf.int32 for spec in self.input_signature.values()): + int64_spec = { + key: tf.TensorSpec( + shape=spec.shape, dtype=tf.int64 if spec.dtype == tf.int32 else spec.dtype, name=spec.name + ) + for key, spec in self.input_signature.items() + } + int64_serving = self.serving.get_concrete_function(int64_spec) + signatures = {"serving_default": serving_default, "int64_serving": int64_serving} + else: + signatures = serving_default + saved_model_dir = os.path.join(save_directory, "saved_model", str(version)) + self.save(saved_model_dir, include_optimizer=False, signatures=signatures) + logger.info(f"Saved model created in {saved_model_dir}") + + # Save configuration file + self.config.architectures = [self.__class__.__name__[2:]] + + # If we have a custom model, we copy the file defining it in the folder and set the attributes so it can be + # loaded from the Hub. + if self._auto_class is not None: + custom_object_save(self, save_directory, config=self.config) + + self.config.save_pretrained(save_directory) + if self.can_generate(): + self.generation_config.save_pretrained(save_directory) + + # If we save using the predefined names, we can load using `from_pretrained` + weights_name = SAFE_WEIGHTS_NAME if safe_serialization else TF2_WEIGHTS_NAME + output_model_file = os.path.join(save_directory, weights_name) + + shards, index = tf_shard_checkpoint(self.weights, max_shard_size, weights_name=weights_name) + + # Clean the folder from a previous save + for filename in os.listdir(save_directory): + full_filename = os.path.join(save_directory, filename) + # If we have a shard file that is not going to be replaced, we delete it, but only from the main process + # in distributed settings to avoid race conditions. + weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "") + if filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) and filename not in shards: + os.remove(full_filename) + + if index is None: + if safe_serialization: + state_dict = {strip_model_name_and_prefix(w.name): w.value() for w in self.weights} + safe_save_file(state_dict, output_model_file, metadata={"format": "tf"}) + else: + self.save_weights(output_model_file) + logger.info(f"Model weights saved in {output_model_file}") + else: + save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else TF2_WEIGHTS_INDEX_NAME + save_index_file = os.path.join(save_directory, save_index_file) + # Save the index as well + with open(save_index_file, "w", encoding="utf-8") as index_file: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + index_file.write(content) + logger.info( + f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be " + f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) + for shard_file, shard in shards.items(): + if safe_serialization: + shard_state_dict = {strip_model_name_and_prefix(w.name): w.value() for w in shard} + safe_save_file( + shard_state_dict, os.path.join(save_directory, shard_file), metadata={"format": "tf"} + ) + else: + with h5py.File(os.path.join(save_directory, shard_file), mode="w") as shard_file: + layers = [] + for layer in sorted(shard, key=lambda x: x.name): + if "model." in layer.name or len(layer.name.split("/")) == 1: + layer_name = layer.name + else: + layer_name = "/".join(layer.name.split("/")[1:]) + param_dset = shard_file.create_dataset( + layer_name, layer.numpy().shape, dtype=layer.numpy().dtype + ) + param_dset[:] = layer.numpy() + layers.append(layer_name.encode("utf8")) + save_attributes_to_hdf5_group(shard_file, "layer_names", layers) + + if push_to_hub: + self._upload_modified_files( + save_directory, + repo_id, + files_timestamps, + commit_message=commit_message, + token=token, + ) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str | os.PathLike | None, + *model_args, + config: PretrainedConfig | str | os.PathLike | None = None, + cache_dir: str | os.PathLike | None = None, + ignore_mismatched_sizes: bool = False, + force_download: bool = False, + local_files_only: bool = False, + token: str | bool | None = None, + revision: str = "main", + use_safetensors: bool | None = None, + **kwargs, + ): + r""" + Instantiate a pretrained TF 2.0 model from a pre-trained model configuration. + + The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come + pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning + task. + + The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those + weights are discarded. + + Parameters: + pretrained_model_name_or_path (`str`, *optional*): + Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + - A path to a *directory* containing model weights saved using + [`~TFPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. + - A path or url to a *PyTorch state_dict save file* (e.g, `./pt_model/pytorch_model.bin`). In this + case, `from_pt` should be set to `True` and a configuration object should be provided as `config` + argument. This loading path is slower than converting the PyTorch model in a TensorFlow model + using the provided conversion scripts and loading the TensorFlow model afterwards. + - `None` if you are both providing the configuration and state dictionary (resp. with keyword + arguments `config` and `state_dict`). + model_args (sequence of positional arguments, *optional*): + All remaining positional arguments will be passed to the underlying model's `__init__` method. + config (`Union[PretrainedConfig, str]`, *optional*): + Can be either: + + - an instance of a class derived from [`PretrainedConfig`], + - a string valid as input to [`~PretrainedConfig.from_pretrained`]. + + Configuration for the model to use instead of an automatically loaded configuration. Configuration can + be automatically loaded when: + + - The model is a model provided by the library (loaded with the *model id* string of a pretrained + model). + - The model was saved using [`~TFPreTrainedModel.save_pretrained`] and is reloaded by supplying the + save directory. + - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a + configuration JSON file named *config.json* is found in the directory. + from_pt (`bool`, *optional*, defaults to `False`): + Load the model weights from a PyTorch state_dict save file (see docstring of + `pretrained_model_name_or_path` argument). + ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`): + Whether or not to raise an error if some of the weights from the checkpoint do not have the same size + as the weights of the model (if for instance, you are instantiating a model with 10 labels from a + checkpoint with 3 labels). + cache_dir (`str`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download: + Deprecated and ignored. All downloads are now resumed by default when possible. + Will be removed in v5 of Transformers. + proxies: + (`dict[str, str], `optional`): A dictionary of proxy servers to use by protocol or endpoint, e.g., + `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): Whether ot not to also return a + dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (e.g., not try downloading the model). + token (`str` or `bool`, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use + the token generated when running `hf auth login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + + + + + To test a pull request you made on the Hub, you can pass `revision="refs/pr/"`. + + + + mirror (`str`, *optional*): + Mirror source to accelerate downloads in China. If you are from China and have an accessibility + problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. + Please refer to the mirror site for more information. + subfolder (`str`, *optional*, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can + specify the folder name here. + tf_to_pt_weight_rename (`Callable`, *optional*): + A function that is called to transform the names of weights during the PyTorch to TensorFlow + crossloading process. This is not necessary for most models, but is useful to allow composite models to + be crossloaded correctly. + use_safetensors (`bool`, *optional*, defaults to `None`): + Whether or not to use `safetensors` checkpoints. Defaults to `None`. If not specified and `safetensors` + is not installed, it will be set to `False`. + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., + `output_attentions=True`). Behaves differently depending on whether a `config` is provided or + automatically loaded: + + - If a configuration is provided with `config`, `**kwargs` will be directly passed to the + underlying model's `__init__` method (we assume all relevant updates to the configuration have + already been done) + - If a configuration is not provided, `kwargs` will be first passed to the configuration class + initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that + corresponds to a configuration attribute will be used to override said attribute with the + supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute + will be passed to the underlying model's `__init__` function. + + Examples: + + ```python + >>> from transformers import BertConfig, TFBertModel + + >>> # Download model and configuration from huggingface.co and cache. + >>> model = TFBertModel.from_pretrained("google-bert/bert-base-uncased") + >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable). + >>> model = TFBertModel.from_pretrained("./test/saved_model/") + >>> # Update configuration during loading. + >>> model = TFBertModel.from_pretrained("google-bert/bert-base-uncased", output_attentions=True) + >>> assert model.config.output_attentions == True + >>> # Loading from a Pytorch model file instead of a TensorFlow checkpoint (slower, for example purposes, not runnable). + >>> config = BertConfig.from_json_file("./pt_model/my_pt_model_config.json") + >>> model = TFBertModel.from_pretrained("./pt_model/my_pytorch_model.bin", from_pt=True, config=config) + ```""" + from_pt = kwargs.pop("from_pt", False) + resume_download = kwargs.pop("resume_download", None) + proxies = kwargs.pop("proxies", None) + output_loading_info = kwargs.pop("output_loading_info", False) + use_auth_token = kwargs.pop("use_auth_token", None) + trust_remote_code = kwargs.pop("trust_remote_code", None) + _ = kwargs.pop("mirror", None) + load_weight_prefix = kwargs.pop("load_weight_prefix", None) + from_pipeline = kwargs.pop("_from_pipeline", None) + from_auto_class = kwargs.pop("_from_auto", False) + subfolder = kwargs.pop("subfolder", "") + commit_hash = kwargs.pop("_commit_hash", None) + tf_to_pt_weight_rename = kwargs.pop("tf_to_pt_weight_rename", None) + + # Not relevant for TF models + _ = kwargs.pop("adapter_kwargs", None) + + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if token is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + token = use_auth_token + + if trust_remote_code is True: + logger.warning( + "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is" + " ignored." + ) + + user_agent = {"file_type": "model", "framework": "tensorflow", "from_auto_class": from_auto_class} + if from_pipeline is not None: + user_agent["using_pipeline"] = from_pipeline + + if is_offline_mode() and not local_files_only: + logger.info("Offline mode: forcing local_files_only=True") + local_files_only = True + + if use_safetensors is None and not is_safetensors_available(): + use_safetensors = False + + # Load config if we don't provide a configuration + if not isinstance(config, PretrainedConfig): + config_path = config if config is not None else pretrained_model_name_or_path + config, model_kwargs = cls.config_class.from_pretrained( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + _from_auto=from_auto_class, + _from_pipeline=from_pipeline, + _commit_hash=commit_hash, + **kwargs, + ) + else: + model_kwargs = kwargs + + if commit_hash is None: + commit_hash = getattr(config, "_commit_hash", None) + + # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the + # index of the files. + is_sharded = False + # Load model + if pretrained_model_name_or_path is not None: + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + is_local = os.path.isdir(pretrained_model_name_or_path) + if is_local: + if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)): + # Load from a PyTorch checkpoint in priority if from_pt + archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) + elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)): + # Load from a sharded PyTorch checkpoint + archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME) + is_sharded = True + elif use_safetensors is not False and os.path.isfile( + os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME) + ): + # Load from a safetensors checkpoint + archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME) + elif use_safetensors is not False and os.path.isfile( + os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME) + ): + # Load from a sharded safetensors checkpoint + archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME) + is_sharded = True + elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)): + # Load from a TF 2.0 checkpoint + archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME) + elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME)): + # Load from a sharded TF 2.0 checkpoint + archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME) + is_sharded = True + + # At this stage we don't have a weight file so we will raise an error. + elif use_safetensors: + raise OSError( + f"Error no file named {SAFE_WEIGHTS_NAME} or {SAFE_WEIGHTS_INDEX_NAME} found in directory {pretrained_model_name_or_path}. " + f"Please make sure that the model has been saved with `safe_serialization=True` or do not " + f"set `use_safetensors=True`." + ) + elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)) or os.path.isfile( + os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME) + ): + raise OSError( + f"Error no file named {TF2_WEIGHTS_NAME} or {SAFE_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} " + "but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those " + "weights." + ) + else: + raise OSError( + f"Error no file named {TF2_WEIGHTS_NAME}, {SAFE_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory " + f"{pretrained_model_name_or_path}." + ) + elif os.path.isfile(pretrained_model_name_or_path): + archive_file = pretrained_model_name_or_path + is_local = True + elif os.path.isfile(pretrained_model_name_or_path + ".index"): + archive_file = pretrained_model_name_or_path + ".index" + is_local = True + elif is_remote_url(pretrained_model_name_or_path): + filename = pretrained_model_name_or_path + resolved_archive_file = download_url(pretrained_model_name_or_path) + else: + # set correct filename + if from_pt: + filename = WEIGHTS_NAME + elif use_safetensors is not False: + filename = SAFE_WEIGHTS_NAME + else: + filename = TF2_WEIGHTS_NAME + + try: + # Load from URL or cache if already cached + cached_file_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, + "proxies": proxies, + "resume_download": resume_download, + "local_files_only": local_files_only, + "token": token, + "user_agent": user_agent, + "revision": revision, + "subfolder": subfolder, + "_raise_exceptions_for_gated_repo": False, + "_raise_exceptions_for_missing_entries": False, + "_commit_hash": commit_hash, + } + resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs) + + # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None + # result when internet is up, the repo and revision exist, but the file does not. + if resolved_archive_file is None and filename == SAFE_WEIGHTS_NAME: + # Did not find the safetensors file, let's fallback to TF. + # No support for sharded safetensors yet, so we'll raise an error if that's all we find. + filename = TF2_WEIGHTS_NAME + resolved_archive_file = cached_file( + pretrained_model_name_or_path, TF2_WEIGHTS_NAME, **cached_file_kwargs + ) + if resolved_archive_file is None and filename == TF2_WEIGHTS_NAME: + # Maybe the checkpoint is sharded, we try to grab the index name in this case. + resolved_archive_file = cached_file( + pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME, **cached_file_kwargs + ) + if resolved_archive_file is not None: + is_sharded = True + if resolved_archive_file is None and filename == WEIGHTS_NAME: + # Maybe the checkpoint is sharded, we try to grab the index name in this case. + resolved_archive_file = cached_file( + pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **cached_file_kwargs + ) + if resolved_archive_file is not None: + is_sharded = True + if resolved_archive_file is None: + # Otherwise, maybe there is a PyTorch or Flax model file. We try those to give a helpful error + # message. + has_file_kwargs = { + "revision": revision, + "proxies": proxies, + "token": token, + "cache_dir": cache_dir, + "local_files_only": local_files_only, + } + if has_file(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME, **has_file_kwargs): + is_sharded = True + elif has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs): + raise OSError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {TF2_WEIGHTS_NAME} but there is a file for PyTorch weights. Use `from_pt=True` to" + " load this model from those weights." + ) + else: + raise OSError( + f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}," + f" {TF2_WEIGHTS_NAME} or {TF_WEIGHTS_NAME}" + ) + + except OSError: + # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted + # to the original exception. + raise + except Exception: + # For any other exception, we throw a generic error. + + raise OSError( + f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it" + " from 'https://huggingface.co/models', make sure you don't have a local directory with the" + f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a" + f" directory containing a file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME} or {TF_WEIGHTS_NAME}" + ) + if is_local: + logger.info(f"loading weights file {archive_file}") + resolved_archive_file = archive_file + filename = resolved_archive_file.split(os.path.sep)[-1] + else: + logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}") + else: + resolved_archive_file = None + + # We'll need to download and cache each checkpoint shard if the checkpoint is sharded. + if is_sharded: + # resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case. + resolved_archive_file, sharded_metadata = get_checkpoint_shard_files( + pretrained_model_name_or_path, + resolved_archive_file, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + revision=revision, + _commit_hash=commit_hash, + ) + + safetensors_from_pt = False + if filename == SAFE_WEIGHTS_NAME: + with safe_open(resolved_archive_file, framework="tf") as f: + safetensors_metadata = f.metadata() + if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax", "mlx"]: + raise OSError( + f"The safetensors archive passed at {resolved_archive_file} does not contain the valid metadata." + " Make sure you save your model with the `save_pretrained` method." + ) + safetensors_from_pt = safetensors_metadata.get("format") == "pt" + elif filename == SAFE_WEIGHTS_INDEX_NAME: + with safe_open(resolved_archive_file[0], framework="tf") as f: + safetensors_metadata = f.metadata() + if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax", "mlx"]: + raise OSError( + f"The safetensors archive passed at {resolved_archive_file} does not contain the valid metadata." + " Make sure you save your model with the `save_pretrained` method." + ) + safetensors_from_pt = safetensors_metadata.get("format") == "pt" + + config.name_or_path = pretrained_model_name_or_path + + # composed models, *e.g.* TFRag, require special treatment when it comes to loading + # pre-trained weights. + if cls._requires_load_weight_prefix and model_kwargs.get("name") is not None: + model_kwargs["load_weight_prefix"] = load_weight_prefix + "/" + model_kwargs.get("name") + + # Instantiate model. + model = cls(config, *model_args, **model_kwargs) + + if tf_to_pt_weight_rename is None and hasattr(model, "tf_to_pt_weight_rename"): + # TODO Matt: This is a temporary workaround to allow weight renaming, but requires a method + # to be defined for each class that requires a rename. We can probably just have a class-level + # dict and a single top-level method or something and cut down a lot of boilerplate code + tf_to_pt_weight_rename = model.tf_to_pt_weight_rename + + if from_pt: + from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model + + # Load from a PyTorch checkpoint + return load_pytorch_checkpoint_in_tf2_model( + model, + resolved_archive_file, + allow_missing_keys=True, + output_loading_info=output_loading_info, + _prefix=load_weight_prefix, + tf_to_pt_weight_rename=tf_to_pt_weight_rename, + ) + + # we might need to extend the variable scope for composite models + if load_weight_prefix is not None: + with tf.compat.v1.variable_scope(load_weight_prefix): + model.build_in_name_scope() # build the network with dummy inputs + else: + model.build_in_name_scope() # build the network with dummy inputs + + if safetensors_from_pt and not is_sharded: + from .modeling_tf_pytorch_utils import load_pytorch_state_dict_in_tf2_model + + with safe_open(resolved_archive_file, framework="tf") as safetensors_archive: + # Load from a PyTorch safetensors checkpoint + # We load in TF format here because PT weights often need to be transposed, and this is much + # faster on GPU. Loading as numpy and transposing on CPU adds several seconds to load times. + return load_pytorch_state_dict_in_tf2_model( + model, + safetensors_archive, + tf_inputs=False, # No need to build the model again + allow_missing_keys=True, + output_loading_info=output_loading_info, + _prefix=load_weight_prefix, + ignore_mismatched_sizes=ignore_mismatched_sizes, + tf_to_pt_weight_rename=tf_to_pt_weight_rename, + ) + elif safetensors_from_pt: + from .modeling_tf_pytorch_utils import load_sharded_pytorch_safetensors_in_tf2_model + + return load_sharded_pytorch_safetensors_in_tf2_model( + model, + resolved_archive_file, + tf_inputs=False, + allow_missing_keys=True, + output_loading_info=output_loading_info, + _prefix=load_weight_prefix, + ignore_mismatched_sizes=ignore_mismatched_sizes, + tf_to_pt_weight_rename=tf_to_pt_weight_rename, + ) + + # 'by_name' allow us to do transfer learning by skipping/adding layers + # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357 + try: + if is_sharded: + for file in resolved_archive_file: + os.path.isfile(file), f"Error retrieving files {file}" + if filename == SAFE_WEIGHTS_INDEX_NAME: + missing_keys, unexpected_keys, mismatched_keys = load_tf_sharded_weights_from_safetensors( + model, + resolved_archive_file, + ignore_mismatched_sizes=ignore_mismatched_sizes, + _prefix=load_weight_prefix, + ) + else: + missing_keys, unexpected_keys, mismatched_keys = load_tf_sharded_weights( + model, + resolved_archive_file, + ignore_mismatched_sizes=ignore_mismatched_sizes, + _prefix=load_weight_prefix, + ) + else: + # Handles both H5 and safetensors + missing_keys, unexpected_keys, mismatched_keys = load_tf_weights( + model, + resolved_archive_file, + ignore_mismatched_sizes=ignore_mismatched_sizes, + _prefix=load_weight_prefix, + ) + except OSError as e: + try: + with open(resolved_archive_file) as f: + if f.read().startswith("version"): + raise OSError( + "You seem to have cloned a repository without having git-lfs installed. Please install " + "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder " + "you cloned." + ) + else: + raise ValueError from e + except (UnicodeDecodeError, ValueError): + raise OSError( + "Unable to load weights from h5 file. " + "If you tried to load a TF 2.0 model from a PyTorch checkpoint, please set from_pt=True. " + ) + + if cls._keys_to_ignore_on_load_missing is not None: + for pat in cls._keys_to_ignore_on_load_missing: + missing_keys = [k for k in missing_keys if re.search(pat, k) is None] + + if cls._keys_to_ignore_on_load_unexpected is not None: + for pat in cls._keys_to_ignore_on_load_unexpected: + unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] + + if len(unexpected_keys) > 0: + logger.warning( + f"Some layers from the model checkpoint at {pretrained_model_name_or_path} were not used when" + f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" + f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or" + " with another architecture (e.g. initializing a BertForSequenceClassification model from a" + " BertForPreTraining model).\n- This IS NOT expected if you are initializing" + f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical" + " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)." + ) + else: + logger.warning(f"All model checkpoint layers were used when initializing {model.__class__.__name__}.\n") + + if len(missing_keys) > 0: + logger.warning( + f"Some layers of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" + " TRAIN this model on a down-stream task to be able to use it for predictions and inference." + ) + elif len(mismatched_keys) == 0: + logger.warning( + f"All the layers of {model.__class__.__name__} were initialized from the model checkpoint at" + f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint" + f" was trained on, you can already use {model.__class__.__name__} for predictions without further" + " training." + ) + if len(mismatched_keys) > 0: + mismatched_warning = "\n".join( + [ + f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" + for key, shape1, shape2 in mismatched_keys + ] + ) + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not" + f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able" + " to use it for predictions and inference." + ) + + # If it is a model with generation capabilities, attempt to load the generation config + if model.can_generate(): + try: + model.generation_config = GenerationConfig.from_pretrained( + pretrained_model_name_or_path, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + _from_auto=from_auto_class, + _from_pipeline=from_pipeline, + **kwargs, + ) + except OSError: + logger.info( + "Generation config file not found, using a generation config created from the model config." + ) + pass + + if output_loading_info: + loading_info = { + "missing_keys": missing_keys, + "unexpected_keys": unexpected_keys, + "mismatched_keys": mismatched_keys, + } + + return model, loading_info + + return model + + def push_to_hub( + self, + repo_id: str, + use_temp_dir: bool | None = None, + commit_message: str | None = None, + private: bool | None = None, + max_shard_size: int | str | None = "10GB", + token: bool | str | None = None, + # (`use_auth_token` is deprecated: we have to keep it here as we don't have **kwargs) + use_auth_token: bool | str | None = None, + create_pr: bool = False, + **base_model_card_args, + ) -> str: + """ + Upload the model files to the 🤗 Model Hub while synchronizing a local clone of the repo in `repo_path_or_name`. + + Parameters: + repo_id (`str`): + The name of the repository you want to push your model to. It should contain your organization name + when pushing to a given organization. + use_temp_dir (`bool`, *optional*): + Whether or not to use a temporary directory to store the files saved before they are pushed to the Hub. + Will default to `True` if there is no directory named like `repo_id`, `False` otherwise. + commit_message (`str`, *optional*): + Message to commit while pushing. Will default to `"Upload model"`. + private (`bool`, *optional*): + Whether to make the repo private. If `None` (default), the repo will be public unless the organization's default is private. This value is ignored if the repo already exists. + token (`bool` or `str`, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `hf auth login` (stored in `~/.huggingface`). Will default to `True` if `repo_url` + is not specified. + max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`): + Only applicable for models. The maximum size for a checkpoint before being sharded. Checkpoints shard + will then be each of size lower than this size. If expressed as a string, needs to be digits followed + by a unit (like `"5MB"`). + create_pr (`bool`, *optional*, defaults to `False`): + Whether or not to create a PR with the uploaded files or directly commit. + + Examples: + + ```python + from transformers import TFAutoModel + + model = TFAutoModel.from_pretrained("google-bert/bert-base-cased") + + # Push the model to your namespace with the name "my-finetuned-bert". + model.push_to_hub("my-finetuned-bert") + + # Push the model to an organization with the name "my-finetuned-bert". + model.push_to_hub("huggingface/my-finetuned-bert") + ``` + """ + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if token is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + token = use_auth_token + + if "repo_path_or_name" in base_model_card_args: + warnings.warn( + "The `repo_path_or_name` argument is deprecated and will be removed in v5 of Transformers. Use " + "`repo_id` instead." + ) + repo_id = base_model_card_args.pop("repo_path_or_name") + # Deprecation warning will be sent after for repo_url and organization + repo_url = base_model_card_args.pop("repo_url", None) + organization = base_model_card_args.pop("organization", None) + + if os.path.isdir(repo_id): + working_dir = repo_id + repo_id = repo_id.split(os.path.sep)[-1] + else: + working_dir = repo_id.split("/")[-1] + + repo_id = self._create_repo( + repo_id, private=private, token=token, repo_url=repo_url, organization=organization + ) + + if use_temp_dir is None: + use_temp_dir = not os.path.isdir(working_dir) + + with working_or_temp_dir(working_dir=working_dir, use_temp_dir=use_temp_dir) as work_dir: + files_timestamps = self._get_files_timestamps(work_dir) + + # Save all files. + self.save_pretrained(work_dir, max_shard_size=max_shard_size) + if hasattr(self, "history") and hasattr(self, "create_model_card"): + # This is a Keras model and we might be able to fish out its History and make a model card out of it + base_model_card_args = { + "output_dir": work_dir, + "model_name": Path(repo_id).name, + } + base_model_card_args.update(base_model_card_args) + self.create_model_card(**base_model_card_args) + + self._upload_modified_files( + work_dir, + repo_id, + files_timestamps, + commit_message=commit_message, + token=token, + create_pr=create_pr, + ) + + @classmethod + def register_for_auto_class(cls, auto_class="TFAutoModel"): + """ + Register this class with a given auto class. This should only be used for custom models as the ones in the + library are already mapped with an auto class. + + + + Args: + auto_class (`str` or `type`, *optional*, defaults to `"TFAutoModel"`): + The auto class to register this new model with. + """ + if not isinstance(auto_class, str): + auto_class = auto_class.__name__ + + import transformers.models.auto as auto_module + + if not hasattr(auto_module, auto_class): + raise ValueError(f"{auto_class} is not a valid auto class.") + + cls._auto_class = auto_class + + +class TFConv1D(keras.layers.Layer): + """ + 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2). + + Basically works like a linear layer but the weights are transposed. + + Args: + nf (`int`): + The number of output features. + nx (`int`): + The number of input features. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation to use to initialize the weights. + kwargs (`dict[str, Any]`, *optional*): + Additional keyword arguments passed along to the `__init__` of `keras.layers.Layer`. + """ + + def __init__(self, nf, nx, initializer_range=0.02, **kwargs): + super().__init__(**kwargs) + self.nf = nf + self.nx = nx + self.initializer_range = initializer_range + + def build(self, input_shape): + if self.built: + return + self.built = True + self.weight = self.add_weight( + "weight", shape=[self.nx, self.nf], initializer=get_initializer(self.initializer_range) + ) + self.bias = self.add_weight("bias", shape=[1, self.nf], initializer=tf.zeros_initializer()) + + def call(self, x): + bz, sl = shape_list(x)[:2] + + x = tf.reshape(x, [-1, self.nx]) + x = tf.matmul(x, self.weight) + self.bias + + x = tf.reshape(x, [bz, sl, self.nf]) + + return x + + +class TFSharedEmbeddings(keras.layers.Layer): + r""" + Construct shared token embeddings. + + The weights of the embedding layer is usually shared with the weights of the linear decoder when doing language + modeling. + + Args: + vocab_size (`int`): + The size of the vocabulary, e.g., the number of unique tokens. + hidden_size (`int`): + The size of the embedding vectors. + initializer_range (`float`, *optional*): + The standard deviation to use when initializing the weights. If no value is provided, it will default to + \\(1/\sqrt{hidden\_size}\\). + kwargs (`dict[str, Any]`, *optional*): + Additional keyword arguments passed along to the `__init__` of `keras.layers.Layer`. + """ + + # TODO (joao): flagged for detection due to embeddings refactor + + def __init__(self, vocab_size: int, hidden_size: int, initializer_range: float | None = None, **kwargs): + super().__init__(**kwargs) + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.initializer_range = hidden_size**-0.5 if initializer_range is None else initializer_range + warnings.warn( + "`TFSharedEmbeddings` is scheduled for deletion in v4.32, use `keras.layers.Embedding` instead.", + DeprecationWarning, + ) + + def build(self, input_shape): + """ + Build shared token embedding layer Shared weights logic adapted from + https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24 + """ + self.weight = self.add_weight( + "weight", shape=[self.vocab_size, self.hidden_size], initializer=get_initializer(self.initializer_range) + ) + super().build(input_shape) + + def get_config(self): + config = { + "vocab_size": self.vocab_size, + "hidden_size": self.hidden_size, + "initializer_range": self.initializer_range, + } + base_config = super().get_config() + + return dict(list(base_config.items()) + list(config.items())) + + def call(self, inputs: tf.Tensor, mode: str = "embedding") -> tf.Tensor: + """ + Get token embeddings of inputs or decode final hidden state. + + Args: + inputs (`tf.Tensor`): + In embedding mode, should be an int64 tensor with shape `[batch_size, length]`. + + In linear mode, should be a float tensor with shape `[batch_size, length, hidden_size]`. + mode (`str`, defaults to `"embedding"`): + A valid value is either `"embedding"` or `"linear"`, the first one indicates that the layer should be + used as an embedding layer, the second one that the layer should be used as a linear decoder. + + Returns: + `tf.Tensor`: In embedding mode, the output is a float32 embedding tensor, with shape `[batch_size, length, + embedding_size]`. + + In linear mode, the output is a float32 with shape `[batch_size, length, vocab_size]`. + + Raises: + ValueError: if `mode` is not valid. + + Shared weights logic is adapted from + [here](https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24). + """ + if mode == "embedding": + return self._embedding(inputs) + elif mode == "linear": + return self._linear(inputs) + else: + raise ValueError(f"mode {mode} is not valid.") + + def _embedding(self, input_ids): + """Applies embedding based on inputs tensor.""" + return tf.gather(self.weight, input_ids) + + def _linear(self, inputs): + """ + Computes logits by running inputs through a linear layer. + + Args: + inputs: A float32 tensor with shape [..., hidden_size] + + Returns: + float32 tensor with shape [..., vocab_size]. + """ + first_dims = shape_list(inputs)[:-1] + x = tf.reshape(inputs, [-1, self.hidden_size]) + logits = tf.matmul(x, self.weight, transpose_b=True) + + return tf.reshape(logits, first_dims + [self.vocab_size]) + + +class TFSequenceSummary(keras.layers.Layer): + """ + Compute a single vector summary of a sequence hidden states. + + Args: + config ([`PretrainedConfig`]): + The config used by the model. Relevant arguments in the config class of the model are (refer to the actual + config class of your model for the default values it uses): + + - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are: + + - `"last"` -- Take the last token hidden state (like XLNet) + - `"first"` -- Take the first token hidden state (like Bert) + - `"mean"` -- Take the mean of all tokens hidden states + - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2) + - `"attn"` -- Not implemented now, use multi-head attention + + - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction. + - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes + (otherwise to `config.hidden_size`). + - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output, + another string or `None` will add no activation. + - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation. + - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation. + + initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation to use to initialize the weights. + kwargs (`dict[str, Any]`, *optional*): + Additional keyword arguments passed along to the `__init__` of `keras.layers.Layer`. + """ + + def __init__(self, config: PretrainedConfig, initializer_range: float = 0.02, **kwargs): + super().__init__(**kwargs) + + self.summary_type = config.summary_type if hasattr(config, "summary_use_proj") else "last" + if self.summary_type == "attn": + # We should use a standard multi-head attention module with absolute positional embedding for that. + # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276 + # We can probably just use the multi-head attention module of PyTorch >=1.1.0 + raise NotImplementedError + + self.has_summary = hasattr(config, "summary_use_proj") and config.summary_use_proj + if self.has_summary: + if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0: + num_classes = config.num_labels + else: + num_classes = config.hidden_size + self.summary = keras.layers.Dense( + num_classes, kernel_initializer=get_initializer(initializer_range), name="summary" + ) + + self.has_activation = False + activation_string = getattr(config, "summary_activation", None) + if activation_string is not None: + self.has_activation = True + self.activation = get_tf_activation(activation_string) + + self.has_first_dropout = hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0 + if self.has_first_dropout: + self.first_dropout = keras.layers.Dropout(config.summary_first_dropout) + + self.has_last_dropout = hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0 + if self.has_last_dropout: + self.last_dropout = keras.layers.Dropout(config.summary_last_dropout) + self.hidden_size = config.hidden_size + + def call(self, inputs, cls_index=None, training=False): + if not isinstance(inputs, (dict, tuple, list)): + hidden_states = inputs + elif isinstance(inputs, (tuple, list)): + hidden_states = inputs[0] + cls_index = inputs[1] if len(inputs) > 1 else None + assert len(inputs) <= 2, "Too many inputs." + else: + hidden_states = inputs.get("hidden_states") + cls_index = inputs.get("cls_index", None) + + if self.summary_type == "last": + output = hidden_states[:, -1] + elif self.summary_type == "first": + output = hidden_states[:, 0] + elif self.summary_type == "mean": + output = tf.reduce_mean(hidden_states, axis=1) + elif self.summary_type == "cls_index": + hidden_shape = shape_list(hidden_states) # e.g. [batch, num choices, seq length, hidden dims] + if cls_index is None: + cls_index = tf.fill( + hidden_shape[:-2], hidden_shape[-2] - 1 + ) # A tensor full of shape [batch] or [batch, num choices] full of sequence length + cls_shape = shape_list(cls_index) + if len(cls_shape) <= len(hidden_shape) - 2: + cls_index = tf.expand_dims(cls_index, axis=-1) + # else: + # cls_index = cls_index[..., tf.newaxis] + # cls_index = cls_index.expand((-1,) * (cls_index.dim()-1) + (hidden_states.size(-1),)) + # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states + output = tf.gather(hidden_states, cls_index, batch_dims=len(hidden_shape) - 2) + output = tf.squeeze( + output, axis=len(hidden_shape) - 2 + ) # shape of output: (batch, num choices, hidden_size) + elif self.summary_type == "attn": + raise NotImplementedError + + if self.has_first_dropout: + output = self.first_dropout(output, training=training) + + if self.has_summary: + output = self.summary(output) + + if self.has_activation: + output = self.activation(output) + + if self.has_last_dropout: + output = self.last_dropout(output, training=training) + + return output + + def build(self, input_shape): + if self.built: + return + self.built = True + if getattr(self, "summary", None) is not None: + with tf.name_scope("summary"): + self.summary.build(self.hidden_size) + + +def get_initializer(initializer_range: float = 0.02) -> keras.initializers.TruncatedNormal: + """ + Creates a `keras.initializers.TruncatedNormal` with the given range. + + Args: + initializer_range (*float*, defaults to 0.02): Standard deviation of the initializer range. + + Returns: + `keras.initializers.TruncatedNormal`: The truncated normal initializer. + """ + return keras.initializers.TruncatedNormal(stddev=initializer_range) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/optimization.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/optimization.py new file mode 100644 index 0000000000000000000000000000000000000000..688d0f8db56f393910dc05bb35b53213b46ced2b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/optimization.py @@ -0,0 +1,973 @@ +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch optimization for BERT model.""" + +import math +import warnings +from functools import partial +from typing import Optional, Union + +import torch +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau + +from .trainer_pt_utils import LayerWiseDummyOptimizer, LayerWiseDummyScheduler +from .trainer_utils import SchedulerType +from .utils import logging + + +logger = logging.get_logger(__name__) + + +def _get_constant_lambda(_=None): + return 1 + + +def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1): + """ + Create a schedule with a constant learning rate, using the learning rate set in optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + return LambdaLR(optimizer, _get_constant_lambda, last_epoch=last_epoch) + + +def get_reduce_on_plateau_schedule(optimizer: Optimizer, **kwargs): + """ + Create a schedule with a constant learning rate that decreases when a metric has stopped improving. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + kwargs (`dict`, *optional*): + Extra parameters to be passed to the scheduler. See `torch.optim.lr_scheduler.ReduceLROnPlateau` + for possible parameters. + + Return: + `torch.optim.lr_scheduler.ReduceLROnPlateau` with the appropriate schedule. + """ + + return ReduceLROnPlateau(optimizer, **kwargs) + + +def _get_constant_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1.0, num_warmup_steps)) + return 1.0 + + +def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1): + """ + Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate + increases linearly between 0 and the initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + lr_lambda = partial(_get_constant_schedule_with_warmup_lr_lambda, num_warmup_steps=num_warmup_steps) + return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch) + + +def _get_linear_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int, num_training_steps: int): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))) + + +def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1): + """ + Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after + a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + lr_lambda = partial( + _get_linear_schedule_with_warmup_lr_lambda, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + ) + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +def _get_cosine_schedule_with_warmup_lr_lambda( + current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float +): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) + return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) + + +def get_cosine_schedule_with_warmup( + optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1 +): + """ + Create a schedule with a learning rate that decreases following the values of the cosine function between the + initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the + initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + num_cycles (`float`, *optional*, defaults to 0.5): + The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0 + following a half-cosine). + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + lr_lambda = partial( + _get_cosine_schedule_with_warmup_lr_lambda, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + num_cycles=num_cycles, + ) + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +def _get_cosine_with_hard_restarts_schedule_with_warmup_lr_lambda( + current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: int +): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) + if progress >= 1.0: + return 0.0 + return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0)))) + + +def get_cosine_with_hard_restarts_schedule_with_warmup( + optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1 +): + """ + Create a schedule with a learning rate that decreases following the values of the cosine function between the + initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases + linearly between 0 and the initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + num_cycles (`int`, *optional*, defaults to 1): + The number of hard restarts to use. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + lr_lambda = partial( + _get_cosine_with_hard_restarts_schedule_with_warmup_lr_lambda, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + num_cycles=num_cycles, + ) + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +def _get_polynomial_decay_schedule_with_warmup_lr_lambda( + current_step: int, + *, + num_warmup_steps: int, + num_training_steps: int, + lr_end: float, + power: float, + lr_init: int, +): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + elif current_step > num_training_steps: + return lr_end / lr_init # as LambdaLR multiplies by lr_init + else: + lr_range = lr_init - lr_end + decay_steps = num_training_steps - num_warmup_steps + pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps + decay = lr_range * pct_remaining**power + lr_end + return decay / lr_init # as LambdaLR multiplies by lr_init + + +def get_polynomial_decay_schedule_with_warmup( + optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1 +): + """ + Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the + optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the + initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + lr_end (`float`, *optional*, defaults to 1e-7): + The end LR. + power (`float`, *optional*, defaults to 1.0): + Power factor. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT + implementation at + https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37 + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + + """ + + lr_init = optimizer.defaults["lr"] + if not (lr_init > lr_end): + raise ValueError(f"lr_end ({lr_end}) must be smaller than initial lr ({lr_init})") + + lr_lambda = partial( + _get_polynomial_decay_schedule_with_warmup_lr_lambda, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + lr_end=lr_end, + power=power, + lr_init=lr_init, + ) + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +def _get_inverse_sqrt_schedule_lr_lambda(current_step: int, *, num_warmup_steps: int, timescale: Optional[int] = None): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + shift = timescale - num_warmup_steps + decay = 1.0 / math.sqrt((current_step + shift) / timescale) + return decay + + +def get_inverse_sqrt_schedule( + optimizer: Optimizer, num_warmup_steps: int, timescale: Optional[int] = None, last_epoch: int = -1 +): + """ + Create a schedule with an inverse square-root learning rate, from the initial lr set in the optimizer, after a + warmup period which increases lr linearly from 0 to the initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + timescale (`int`, *optional*, defaults to `num_warmup_steps`): + Time scale. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + # Note: this implementation is adapted from + # https://github.com/google-research/big_vision/blob/f071ce68852d56099437004fd70057597a95f6ef/big_vision/utils.py#L930 + + if timescale is None: + timescale = num_warmup_steps or 10_000 + + lr_lambda = partial(_get_inverse_sqrt_schedule_lr_lambda, num_warmup_steps=num_warmup_steps, timescale=timescale) + return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch) + + +def _get_cosine_schedule_with_warmup_lr_lambda( + current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float, min_lr_rate: float = 0.0 +): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) + factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)) + factor = factor * (1 - min_lr_rate) + min_lr_rate + return max(0, factor) + + +def get_cosine_with_min_lr_schedule_with_warmup( + optimizer: Optimizer, + num_warmup_steps: int, + num_training_steps: int, + num_cycles: float = 0.5, + last_epoch: int = -1, + min_lr: Optional[float] = None, + min_lr_rate: Optional[float] = None, +): + """ + Create a schedule with a learning rate that decreases following the values of the cosine function between the + initial lr set in the optimizer to min_lr, after a warmup period during which it increases linearly between 0 and the + initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + num_cycles (`float`, *optional*, defaults to 0.5): + The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0 + following a half-cosine). + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + min_lr (`float`, *optional*): + The minimum learning rate to reach after the cosine schedule. + min_lr_rate (`float`, *optional*): + The minimum learning rate as a ratio of the initial learning rate. If set, `min_lr` should not be set. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + if min_lr is not None and min_lr_rate is not None: + raise ValueError("Only one of min_lr or min_lr_rate should be set") + elif min_lr is not None: + min_lr_rate = min_lr / optimizer.defaults["lr"] + elif min_lr_rate is None: + raise ValueError("One of min_lr or min_lr_rate should be set through the `lr_scheduler_kwargs`") + + lr_lambda = partial( + _get_cosine_schedule_with_warmup_lr_lambda, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + num_cycles=num_cycles, + min_lr_rate=min_lr_rate, + ) + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +def _get_cosine_with_min_lr_schedule_with_warmup_lr_rate_lambda( + current_step: int, + *, + num_warmup_steps: int, + num_training_steps: int, + num_cycles: float, + min_lr_rate: float = 0.0, + warmup_lr_rate: Optional[float] = None, +): + current_step = float(current_step) + num_warmup_steps = float(num_warmup_steps) + num_training_steps = float(num_training_steps) + + if current_step < num_warmup_steps: + if warmup_lr_rate is None: + return (current_step + 1.0) / max(1.0, num_warmup_steps) + else: + warmup_lr_rate = float(warmup_lr_rate) + return warmup_lr_rate + (1.0 - warmup_lr_rate) * (current_step) / (max(1, num_warmup_steps - 1)) + progress = (current_step - num_warmup_steps + 1.0) / (max(1.0, num_training_steps - num_warmup_steps)) + factor = 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress)) + factor = factor * (1 - min_lr_rate) + min_lr_rate + return max(0, factor) + + +def get_cosine_with_min_lr_schedule_with_warmup_lr_rate( + optimizer: Optimizer, + num_warmup_steps: int, + num_training_steps: int, + num_cycles: float = 0.5, + last_epoch: int = -1, + min_lr: Optional[float] = None, + min_lr_rate: Optional[float] = None, + warmup_lr_rate: Optional[float] = None, +): + """ + Create a schedule with a learning rate that decreases following the values of the cosine function between the + initial lr set in the optimizer to min_lr, after a warmup period during which it increases linearly between 0 and the + initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + num_cycles (`float`, *optional*, defaults to 0.5): + The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0 + following a half-cosine). + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + min_lr (`float`, *optional*): + The minimum learning rate to reach after the cosine schedule. + min_lr_rate (`float`, *optional*): + The minimum learning rate as a ratio of the initial learning rate. If set, `min_lr` should not be set. + warmup_lr_rate (`float`, *optional*): + The minimum learning rate as a ratio of the start learning rate. If not set, `warmup_lr_rate` will be treated as float(1/num_warmup_steps). + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + if min_lr is not None and min_lr_rate is not None: + raise ValueError("Only one of min_lr or min_lr_rate should be set") + elif min_lr is not None: + min_lr_rate = min_lr / optimizer.defaults["lr"] + elif min_lr_rate is None: + raise ValueError("One of min_lr or min_lr_rate should be set through the `lr_scheduler_kwargs`") + + lr_lambda = partial( + _get_cosine_with_min_lr_schedule_with_warmup_lr_rate_lambda, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + num_cycles=num_cycles, + min_lr_rate=min_lr_rate, + warmup_lr_rate=warmup_lr_rate, + ) + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +def _get_wsd_scheduler_lambda( + current_step: int, + *, + num_warmup_steps: int, + num_stable_steps: int, + num_decay_steps: int, + warmup_type: str, + decay_type: str, + min_lr_ratio: float, + num_cycles: float, +): + if current_step < num_warmup_steps: + progress = float(current_step) / float(max(1, num_warmup_steps)) + if warmup_type == "linear": + factor = progress + elif warmup_type == "cosine": + factor = 0.5 * (1.0 - math.cos(math.pi * progress)) + elif warmup_type == "1-sqrt": + factor = 1.0 - math.sqrt(1.0 - progress) + factor = factor * (1.0 - min_lr_ratio) + min_lr_ratio + return max(0.0, factor) + + if current_step < num_warmup_steps + num_stable_steps: + return 1.0 + + if current_step < num_warmup_steps + num_stable_steps + num_decay_steps: + progress = float(current_step - num_warmup_steps - num_stable_steps) / float(max(1, num_decay_steps)) + if decay_type == "linear": + factor = 1.0 - progress + elif decay_type == "cosine": + factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)) + elif decay_type == "1-sqrt": + factor = 1.0 - math.sqrt(progress) + factor = factor * (1.0 - min_lr_ratio) + min_lr_ratio + return max(0.0, factor) + return min_lr_ratio + + +def get_wsd_schedule( + optimizer: Optimizer, + num_warmup_steps: int, + num_decay_steps: int, + num_training_steps: Optional[int] = None, + num_stable_steps: Optional[int] = None, + warmup_type: str = "linear", + decay_type: str = "cosine", + min_lr_ratio: float = 0, + num_cycles: float = 0.5, + last_epoch: int = -1, +): + """ + Create a schedule with a learning rate that has three stages: + 1. warmup: increase from min_lr_ratio times the initial learning rate to the initial learning rate following a warmup_type. + 2. stable: constant learning rate. + 3. decay: decrease from the initial learning rate to min_lr_ratio times the initial learning rate following a decay_type. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_decay_steps (`int`): + The number of steps for the decay phase. + num_training_steps (`int`, *optional*): + The total number of training steps. This is the sum of the warmup, stable and decay steps. If `num_stable_steps` is not provided, the stable phase will be `num_training_steps - num_warmup_steps - num_decay_steps`. + num_stable_steps (`int`, *optional*): + The number of steps for the stable phase. Please ensure that `num_warmup_steps + num_stable_steps + num_decay_steps` equals `num_training_steps`, otherwise the other steps will default to the minimum learning rate. + warmup_type (`str`, *optional*, defaults to "linear"): + The type of warmup to use. Can be 'linear', 'cosine' or '1-sqrt'. + decay_type (`str`, *optional*, defaults to "cosine"): + The type of decay to use. Can be 'linear', 'cosine' or '1-sqrt'. + min_lr_ratio (`float`, *optional*, defaults to 0): + The minimum learning rate as a ratio of the initial learning rate. + num_cycles (`float`, *optional*, defaults to 0.5): + The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0 + following a half-cosine). + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + if num_training_steps is None and num_stable_steps is None: + raise ValueError("Either num_training_steps or num_stable_steps must be specified.") + + if num_training_steps is not None and num_stable_steps is not None: + warnings.warn("Both num_training_steps and num_stable_steps are specified. num_stable_steps will be used.") + + if warmup_type not in ["linear", "cosine", "1-sqrt"]: + raise ValueError(f"Unknown warmup type: {warmup_type}, expected 'linear', 'cosine' or '1-sqrt'") + + if decay_type not in ["linear", "cosine", "1-sqrt"]: + raise ValueError(f"Unknown decay type: {decay_type}, expected 'linear', 'cosine' or '1-sqrt'") + + if num_stable_steps is None: + num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps + + lr_lambda = partial( + _get_wsd_scheduler_lambda, + num_warmup_steps=num_warmup_steps, + num_stable_steps=num_stable_steps, + num_decay_steps=num_decay_steps, + warmup_type=warmup_type, + decay_type=decay_type, + min_lr_ratio=min_lr_ratio, + num_cycles=num_cycles, + ) + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +TYPE_TO_SCHEDULER_FUNCTION = { + SchedulerType.LINEAR: get_linear_schedule_with_warmup, + SchedulerType.COSINE: get_cosine_schedule_with_warmup, + SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup, + SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup, + SchedulerType.CONSTANT: get_constant_schedule, + SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup, + SchedulerType.INVERSE_SQRT: get_inverse_sqrt_schedule, + SchedulerType.REDUCE_ON_PLATEAU: get_reduce_on_plateau_schedule, + SchedulerType.COSINE_WITH_MIN_LR: get_cosine_with_min_lr_schedule_with_warmup, + SchedulerType.COSINE_WARMUP_WITH_MIN_LR: get_cosine_with_min_lr_schedule_with_warmup_lr_rate, + SchedulerType.WARMUP_STABLE_DECAY: get_wsd_schedule, +} + + +def get_scheduler( + name: Union[str, SchedulerType], + optimizer: Optimizer, + num_warmup_steps: Optional[int] = None, + num_training_steps: Optional[int] = None, + scheduler_specific_kwargs: Optional[dict] = None, +): + """ + Unified API to get any scheduler from its name. + + Args: + name (`str` or `SchedulerType`): + The name of the scheduler to use. + optimizer (`torch.optim.Optimizer`): + The optimizer that will be used during training. + num_warmup_steps (`int`, *optional*): + The number of warmup steps to do. This is not required by all schedulers (hence the argument being + optional), the function will raise an error if it's unset and the scheduler type requires it. + num_training_steps (`int``, *optional*): + The number of training steps to do. This is not required by all schedulers (hence the argument being + optional), the function will raise an error if it's unset and the scheduler type requires it. + scheduler_specific_kwargs (`dict`, *optional*): + Extra parameters for schedulers such as cosine with restarts. Mismatched scheduler types and scheduler + parameters will cause the scheduler function to raise a TypeError. + """ + name = SchedulerType(name) + schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] + + # If a `LayerWiseDummyOptimizer` is passed we extract the optimizer dict and + # recursively call `get_scheduler` to get the proper schedulers on each parameter + if optimizer is not None and isinstance(optimizer, LayerWiseDummyOptimizer): + optimizer_dict = optimizer.optimizer_dict + scheduler_dict = {} + + for param in optimizer_dict: + scheduler_dict[param] = get_scheduler( + name, + optimizer=optimizer_dict[param], + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + scheduler_specific_kwargs=scheduler_specific_kwargs, + ) + + def scheduler_hook(param): + # Since the optimizer hook has been already attached we only need to + # attach the scheduler hook, the gradients have been zeroed here + scheduler_dict[param].step() + + for param in optimizer_dict: + if param.requires_grad: + param.register_post_accumulate_grad_hook(scheduler_hook) + + return LayerWiseDummyScheduler(optimizer_dict=optimizer_dict, lr=optimizer.defaults["lr"]) + + if name == SchedulerType.CONSTANT: + return schedule_func(optimizer) + + if scheduler_specific_kwargs is None: + scheduler_specific_kwargs = {} + + if name == SchedulerType.REDUCE_ON_PLATEAU: + return schedule_func(optimizer, **scheduler_specific_kwargs) + + # All other schedulers require `num_warmup_steps` + if num_warmup_steps is None: + raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") + + if name == SchedulerType.CONSTANT_WITH_WARMUP: + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps) + + if name == SchedulerType.INVERSE_SQRT: + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps) + + # wsd scheduler requires either num_training_steps or num_stable_steps + if name == SchedulerType.WARMUP_STABLE_DECAY: + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + **scheduler_specific_kwargs, + ) + + # All other schedulers require `num_training_steps` + if num_training_steps is None: + raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") + + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + **scheduler_specific_kwargs, + ) + + +class Adafactor(Optimizer): + """ + AdaFactor pytorch implementation can be used as a drop in replacement for Adam original fairseq code: + https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py + + Paper: *Adafactor: Adaptive Learning Rates with Sublinear Memory Cost* https://huggingface.co/papers/1804.04235 Note that + this optimizer internally adjusts the learning rate depending on the `scale_parameter`, `relative_step` and + `warmup_init` options. To use a manual (external) learning rate schedule you should set `scale_parameter=False` and + `relative_step=False`. + + Arguments: + params (`Iterable[nn.parameter.Parameter]`): + Iterable of parameters to optimize or dictionaries defining parameter groups. + lr (`float`, *optional*): + The external learning rate. + eps (`tuple[float, float]`, *optional*, defaults to `(1e-30, 0.001)`): + Regularization constants for square gradient and parameter scale respectively + clip_threshold (`float`, *optional*, defaults to 1.0): + Threshold of root mean square of final gradient update + decay_rate (`float`, *optional*, defaults to -0.8): + Coefficient used to compute running averages of square + beta1 (`float`, *optional*): + Coefficient used for computing running averages of gradient + weight_decay (`float`, *optional*, defaults to 0.0): + Weight decay (L2 penalty) + scale_parameter (`bool`, *optional*, defaults to `True`): + If True, learning rate is scaled by root mean square + relative_step (`bool`, *optional*, defaults to `True`): + If True, time-dependent learning rate is computed instead of external learning rate + warmup_init (`bool`, *optional*, defaults to `False`): + Time-dependent learning rate computation depends on whether warm-up initialization is being used + + This implementation handles low-precision (FP16, bfloat) values, but we have not thoroughly tested. + + Recommended T5 finetuning settings (https://discuss.huggingface.co/t/t5-finetuning-tips/684/3): + + - Training without LR warmup or clip_threshold is not recommended. + + - use scheduled LR warm-up to fixed LR + - use clip_threshold=1.0 (https://huggingface.co/papers/1804.04235) + - Disable relative updates + - Use scale_parameter=False + - Additional optimizer operations like gradient clipping should not be used alongside Adafactor + + Example: + + ```python + Adafactor(model.parameters(), scale_parameter=False, relative_step=False, warmup_init=False, lr=1e-3) + ``` + + Others reported the following combination to work well: + + ```python + Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None) + ``` + + When using `lr=None` with [`Trainer`] you will most likely need to use [`~optimization.AdafactorSchedule`] + scheduler as following: + + ```python + from transformers.optimization import Adafactor, AdafactorSchedule + + optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None) + lr_scheduler = AdafactorSchedule(optimizer) + trainer = Trainer(..., optimizers=(optimizer, lr_scheduler)) + ``` + + Usage: + + ```python + # replace AdamW with Adafactor + optimizer = Adafactor( + model.parameters(), + lr=1e-3, + eps=(1e-30, 1e-3), + clip_threshold=1.0, + decay_rate=-0.8, + beta1=None, + weight_decay=0.0, + relative_step=False, + scale_parameter=False, + warmup_init=False, + ) + ```""" + + def __init__( + self, + params, + lr=None, + eps=(1e-30, 1e-3), + clip_threshold=1.0, + decay_rate=-0.8, + beta1=None, + weight_decay=0.0, + scale_parameter=True, + relative_step=True, + warmup_init=False, + ): + if lr is not None and relative_step: + raise ValueError("Cannot combine manual `lr` and `relative_step=True` options") + if warmup_init and not relative_step: + raise ValueError("`warmup_init=True` requires `relative_step=True`") + + defaults = { + "lr": lr, + "eps": eps, + "clip_threshold": clip_threshold, + "decay_rate": decay_rate, + "beta1": beta1, + "weight_decay": weight_decay, + "scale_parameter": scale_parameter, + "relative_step": relative_step, + "warmup_init": warmup_init, + } + super().__init__(params, defaults) + + @staticmethod + def _get_lr(param_group, param_state): + rel_step_sz = param_group["lr"] + if param_group["relative_step"]: + min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2 + rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"])) + param_scale = 1.0 + if param_group["scale_parameter"]: + param_scale = max(param_group["eps"][1], param_state["RMS"]) + return param_scale * rel_step_sz + + @staticmethod + def _get_options(param_group, param_shape): + factored = len(param_shape) >= 2 + use_first_moment = param_group["beta1"] is not None + return factored, use_first_moment + + @staticmethod + def _rms(tensor): + return tensor.norm(2) / (tensor.numel() ** 0.5) + + @staticmethod + def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col): + # copy from fairseq's adafactor implementation: + # https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505 + r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1) + c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() + return torch.mul(r_factor, c_factor) + + @torch.no_grad() + def step(self, closure=None): + """ + Performs a single optimization step + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad + if grad.dtype in {torch.float16, torch.bfloat16}: + grad = grad.float() + if grad.is_sparse: + raise RuntimeError("Adafactor does not support sparse gradients.") + + state = self.state[p] + grad_shape = grad.shape + + factored, use_first_moment = self._get_options(group, grad_shape) + # State Initialization + if len(state) == 0: + state["step"] = 0 + + if use_first_moment: + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(grad) + if factored: + state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad) + state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad) + else: + state["exp_avg_sq"] = torch.zeros_like(grad) + + state["RMS"] = 0 + else: + if use_first_moment: + state["exp_avg"] = state["exp_avg"].to(grad) + if factored: + state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad) + state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad) + else: + state["exp_avg_sq"] = state["exp_avg_sq"].to(grad) + + p_data_fp32 = p + if p.dtype in {torch.float16, torch.bfloat16}: + p_data_fp32 = p_data_fp32.float() + + state["step"] += 1 + state["RMS"] = self._rms(p_data_fp32) + lr = self._get_lr(group, state) + + beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) + update = (grad**2) + group["eps"][0] + if factored: + exp_avg_sq_row = state["exp_avg_sq_row"] + exp_avg_sq_col = state["exp_avg_sq_col"] + + exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t)) + exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t)) + + # Approximation of exponential moving average of square of gradient + update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update.mul_(grad) + else: + exp_avg_sq = state["exp_avg_sq"] + + exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t)) + update = exp_avg_sq.rsqrt().mul_(grad) + + update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0)) + update.mul_(lr) + + if use_first_moment: + exp_avg = state["exp_avg"] + exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"])) + update = exp_avg + + if group["weight_decay"] != 0: + p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr)) + + p_data_fp32.add_(-update) + + if p.dtype in {torch.float16, torch.bfloat16}: + p.copy_(p_data_fp32) + + return loss + + +class AdafactorSchedule(LambdaLR): + """ + Since [`~optimization.Adafactor`] performs its own scheduling, if the training loop relies on a scheduler (e.g., + for logging), this class creates a proxy object that retrieves the current lr values from the optimizer. + + It returns `initial_lr` during startup and the actual `lr` during stepping. + """ + + def __init__(self, optimizer, initial_lr=0.0): + def lr_lambda(_): + return initial_lr + + for group in optimizer.param_groups: + group["initial_lr"] = initial_lr + super().__init__(optimizer, lr_lambda) + for group in optimizer.param_groups: + del group["initial_lr"] + + def get_lr(self): + opt = self.optimizer + lrs = [ + opt._get_lr(group, opt.state[group["params"][0]]) + for group in opt.param_groups + if group["params"][0].grad is not None + ] + if len(lrs) == 0: + lrs = self.base_lrs # if called before stepping + return lrs + + +def get_adafactor_schedule(optimizer, initial_lr=0.0): + """ + Get a proxy schedule for [`~optimization.Adafactor`] + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + initial_lr (`float`, *optional*, defaults to 0.0): + Initial lr + + Return: + [`~optimization.Adafactor`] proxy schedule object. + + + """ + return AdafactorSchedule(optimizer, initial_lr) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/optimization_tf.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/optimization_tf.py new file mode 100644 index 0000000000000000000000000000000000000000..71a77251f2bf9431a08295b8daabbcbe576de71b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/optimization_tf.py @@ -0,0 +1,378 @@ +# Copyright 2019 The TensorFlow Authors, The Hugging Face Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Functions and classes related to optimization (weight updates).""" + +from typing import Callable, Optional, Union + +import tensorflow as tf + + +try: + from tf_keras.optimizers.legacy import Adam +except (ImportError, ModuleNotFoundError): + from tensorflow.keras.optimizers.legacy import Adam + +from .modeling_tf_utils import keras + + +# This block because Keras loves randomly moving things to different places - this changed somewhere between 2.10 - 2.15 +if hasattr(keras.optimizers.schedules, "learning_rate_schedule"): + schedules = keras.optimizers.schedules.learning_rate_schedule +else: + schedules = keras.optimizers.schedules + + +class WarmUp(schedules.LearningRateSchedule): + """ + Applies a warmup schedule on a given learning rate decay schedule. + + Args: + initial_learning_rate (`float`): + The initial learning rate for the schedule after the warmup (so this will be the learning rate at the end + of the warmup). + decay_schedule_fn (`Callable`): + The schedule function to apply after the warmup for the rest of training. + warmup_steps (`int`): + The number of steps for the warmup part of training. + power (`float`, *optional*, defaults to 1.0): + The power to use for the polynomial warmup (defaults is a linear warmup). + name (`str`, *optional*): + Optional name prefix for the returned tensors during the schedule. + """ + + def __init__( + self, + initial_learning_rate: float, + decay_schedule_fn: Callable, + warmup_steps: int, + power: float = 1.0, + name: Optional[str] = None, + ): + super().__init__() + self.initial_learning_rate = initial_learning_rate + self.warmup_steps = warmup_steps + self.power = power + self.decay_schedule_fn = decay_schedule_fn + self.name = name + + def __call__(self, step): + with tf.name_scope(self.name or "WarmUp") as name: + # Implements polynomial warmup. i.e., if global_step < warmup_steps, the + # learning rate will be `global_step/num_warmup_steps * init_lr`. + global_step_float = tf.cast(step, tf.float32) + warmup_steps_float = tf.cast(self.warmup_steps, tf.float32) + warmup_percent_done = global_step_float / warmup_steps_float + warmup_learning_rate = self.initial_learning_rate * tf.math.pow(warmup_percent_done, self.power) + return tf.cond( + global_step_float < warmup_steps_float, + lambda: warmup_learning_rate, + lambda: self.decay_schedule_fn(step - self.warmup_steps), + name=name, + ) + + def get_config(self): + return { + "initial_learning_rate": self.initial_learning_rate, + "decay_schedule_fn": self.decay_schedule_fn, + "warmup_steps": self.warmup_steps, + "power": self.power, + "name": self.name, + } + + +def create_optimizer( + init_lr: float, + num_train_steps: int, + num_warmup_steps: int, + min_lr_ratio: float = 0.0, + adam_beta1: float = 0.9, + adam_beta2: float = 0.999, + adam_epsilon: float = 1e-8, + adam_clipnorm: Optional[float] = None, + adam_global_clipnorm: Optional[float] = None, + weight_decay_rate: float = 0.0, + power: float = 1.0, + include_in_weight_decay: Optional[list[str]] = None, +): + """ + Creates an optimizer with a learning rate schedule using a warmup phase followed by a linear decay. + + Args: + init_lr (`float`): + The desired learning rate at the end of the warmup phase. + num_train_steps (`int`): + The total number of training steps. + num_warmup_steps (`int`): + The number of warmup steps. + min_lr_ratio (`float`, *optional*, defaults to 0): + The final learning rate at the end of the linear decay will be `init_lr * min_lr_ratio`. + adam_beta1 (`float`, *optional*, defaults to 0.9): + The beta1 to use in Adam. + adam_beta2 (`float`, *optional*, defaults to 0.999): + The beta2 to use in Adam. + adam_epsilon (`float`, *optional*, defaults to 1e-8): + The epsilon to use in Adam. + adam_clipnorm (`float`, *optional*, defaults to `None`): + If not `None`, clip the gradient norm for each weight tensor to this value. + adam_global_clipnorm (`float`, *optional*, defaults to `None`) + If not `None`, clip gradient norm to this value. When using this argument, the norm is computed over all + weight tensors, as if they were concatenated into a single vector. + weight_decay_rate (`float`, *optional*, defaults to 0): + The weight decay to use. + power (`float`, *optional*, defaults to 1.0): + The power to use for PolynomialDecay. + include_in_weight_decay (`list[str]`, *optional*): + List of the parameter names (or re patterns) to apply weight decay to. If none is passed, weight decay is + applied to all parameters except bias and layer norm parameters. + """ + # Implements linear decay of the learning rate. + lr_schedule = schedules.PolynomialDecay( + initial_learning_rate=init_lr, + decay_steps=num_train_steps - num_warmup_steps, + end_learning_rate=init_lr * min_lr_ratio, + power=power, + ) + if num_warmup_steps: + lr_schedule = WarmUp( + initial_learning_rate=init_lr, + decay_schedule_fn=lr_schedule, + warmup_steps=num_warmup_steps, + ) + if weight_decay_rate > 0.0: + optimizer = AdamWeightDecay( + learning_rate=lr_schedule, + weight_decay_rate=weight_decay_rate, + beta_1=adam_beta1, + beta_2=adam_beta2, + epsilon=adam_epsilon, + clipnorm=adam_clipnorm, + global_clipnorm=adam_global_clipnorm, + exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"], + include_in_weight_decay=include_in_weight_decay, + ) + else: + optimizer = keras.optimizers.Adam( + learning_rate=lr_schedule, + beta_1=adam_beta1, + beta_2=adam_beta2, + epsilon=adam_epsilon, + clipnorm=adam_clipnorm, + global_clipnorm=adam_global_clipnorm, + ) + # We return the optimizer and the LR scheduler in order to better track the + # evolution of the LR independently of the optimizer. + return optimizer, lr_schedule + + +class AdamWeightDecay(Adam): + """ + Adam enables L2 weight decay and clip_by_global_norm on gradients. Just adding the square of the weights to the + loss function is *not* the correct way of using L2 regularization/weight decay with Adam, since that will interact + with the m and v parameters in strange ways as shown in [Decoupled Weight Decay + Regularization](https://huggingface.co/papers/1711.05101). + + Instead we want to decay the weights in a manner that doesn't interact with the m/v parameters. This is equivalent + to adding the square of the weights to the loss with plain (non-momentum) SGD. + + Args: + learning_rate (`Union[float, LearningRateSchedule]`, *optional*, defaults to 0.001): + The learning rate to use or a schedule. + beta_1 (`float`, *optional*, defaults to 0.9): + The beta1 parameter in Adam, which is the exponential decay rate for the 1st momentum estimates. + beta_2 (`float`, *optional*, defaults to 0.999): + The beta2 parameter in Adam, which is the exponential decay rate for the 2nd momentum estimates. + epsilon (`float`, *optional*, defaults to 1e-07): + The epsilon parameter in Adam, which is a small constant for numerical stability. + amsgrad (`bool`, *optional*, defaults to `False`): + Whether to apply AMSGrad variant of this algorithm or not, see [On the Convergence of Adam and + Beyond](https://huggingface.co/papers/1904.09237). + weight_decay_rate (`float`, *optional*, defaults to 0.0): + The weight decay to apply. + include_in_weight_decay (`list[str]`, *optional*): + List of the parameter names (or re patterns) to apply weight decay to. If none is passed, weight decay is + applied to all parameters by default (unless they are in `exclude_from_weight_decay`). + exclude_from_weight_decay (`list[str]`, *optional*): + List of the parameter names (or re patterns) to exclude from applying weight decay to. If a + `include_in_weight_decay` is passed, the names in it will supersede this list. + name (`str`, *optional*, defaults to `"AdamWeightDecay"`): + Optional name for the operations created when applying gradients. + kwargs (`dict[str, Any]`, *optional*): + Keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by + norm; `clipvalue` is clip gradients by value, `decay` is included for backward compatibility to allow time + inverse decay of learning rate. `lr` is included for backward compatibility, recommended to use + `learning_rate` instead. + """ + + def __init__( + self, + learning_rate: Union[float, schedules.LearningRateSchedule] = 0.001, + beta_1: float = 0.9, + beta_2: float = 0.999, + epsilon: float = 1e-7, + amsgrad: bool = False, + weight_decay_rate: float = 0.0, + include_in_weight_decay: Optional[list[str]] = None, + exclude_from_weight_decay: Optional[list[str]] = None, + name: str = "AdamWeightDecay", + **kwargs, + ): + super().__init__(learning_rate, beta_1, beta_2, epsilon, amsgrad, name, **kwargs) + self.weight_decay_rate = weight_decay_rate + self._include_in_weight_decay = include_in_weight_decay + self._exclude_from_weight_decay = exclude_from_weight_decay + + @classmethod + def from_config(cls, config): + """Creates an optimizer from its config with WarmUp custom object.""" + custom_objects = {"WarmUp": WarmUp} + return super().from_config(config, custom_objects=custom_objects) + + def _prepare_local(self, var_device, var_dtype, apply_state): + super()._prepare_local(var_device, var_dtype, apply_state) + apply_state[(var_device, var_dtype)]["weight_decay_rate"] = tf.constant( + self.weight_decay_rate, name="adam_weight_decay_rate" + ) + + def _decay_weights_op(self, var, learning_rate, apply_state): + do_decay = self._do_use_weight_decay(var.name) + if do_decay: + return var.assign_sub( + learning_rate * var * apply_state[(var.device, var.dtype.base_dtype)]["weight_decay_rate"], + use_locking=self._use_locking, + ) + return tf.no_op() + + def apply_gradients(self, grads_and_vars, name=None, **kwargs): + grads, tvars = list(zip(*grads_and_vars)) + return super().apply_gradients(zip(grads, tvars), name=name, **kwargs) + + def _get_lr(self, var_device, var_dtype, apply_state): + """Retrieves the learning rate with the given state.""" + if apply_state is None: + return self._decayed_lr_t[var_dtype], {} + + apply_state = apply_state or {} + coefficients = apply_state.get((var_device, var_dtype)) + if coefficients is None: + coefficients = self._fallback_apply_state(var_device, var_dtype) + apply_state[(var_device, var_dtype)] = coefficients + + return coefficients["lr_t"], {"apply_state": apply_state} + + def _resource_apply_dense(self, grad, var, apply_state=None): + lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state) + decay = self._decay_weights_op(var, lr_t, apply_state) + with tf.control_dependencies([decay]): + return super()._resource_apply_dense(grad, var, **kwargs) + + def _resource_apply_sparse(self, grad, var, indices, apply_state=None): + lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state) + decay = self._decay_weights_op(var, lr_t, apply_state) + with tf.control_dependencies([decay]): + return super()._resource_apply_sparse(grad, var, indices, **kwargs) + + def get_config(self): + config = super().get_config() + config.update({"weight_decay_rate": self.weight_decay_rate}) + return config + + def _do_use_weight_decay(self, param_name): + """Whether to use L2 weight decay for `param_name`.""" + if self.weight_decay_rate == 0: + return False + + if self._include_in_weight_decay: + for r in self._include_in_weight_decay: + if r in param_name: + return True + + if self._exclude_from_weight_decay: + for r in self._exclude_from_weight_decay: + if r in param_name: + return False + return True + + +# Extracted from https://github.com/OpenNMT/OpenNMT-tf/blob/master/opennmt/optimizers/utils.py +class GradientAccumulator: + """ + Gradient accumulation utility. When used with a distribution strategy, the accumulator should be called in a + replica context. Gradients will be accumulated locally on each replica and without synchronization. Users should + then call `.gradients`, scale the gradients if required, and pass the result to `apply_gradients`. + """ + + # We use the ON_READ synchronization policy so that no synchronization is + # performed on assignment. To get the value, we call .value() which returns the + # value on the current replica without synchronization. + + def __init__(self): + """Initializes the accumulator.""" + self._gradients = [] + self._accum_steps = None + + @property + def step(self): + """Number of accumulated steps.""" + if self._accum_steps is None: + self._accum_steps = tf.Variable( + tf.constant(0, dtype=tf.int64), + trainable=False, + synchronization=tf.VariableSynchronization.ON_READ, + aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA, + ) + + return self._accum_steps.value() + + @property + def gradients(self): + """The accumulated gradients on the current replica.""" + if not self._gradients: + raise ValueError("The accumulator should be called first to initialize the gradients") + return [gradient.value() if gradient is not None else gradient for gradient in self._gradients] + + def __call__(self, gradients): + """Accumulates `gradients` on the current replica.""" + if not self._gradients: + _ = self.step # Create the step variable. + self._gradients.extend( + [ + tf.Variable( + tf.zeros_like(gradient), + trainable=False, + synchronization=tf.VariableSynchronization.ON_READ, + aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA, + ) + if gradient is not None + else gradient + for gradient in gradients + ] + ) + if len(gradients) != len(self._gradients): + raise ValueError(f"Expected {len(self._gradients)} gradients, but got {len(gradients)}") + + for accum_gradient, gradient in zip(self._gradients, gradients): + if accum_gradient is not None and gradient is not None: + accum_gradient.assign_add(gradient) + + self._accum_steps.assign_add(1) + + def reset(self): + """Resets the accumulated gradients on the current replica.""" + if not self._gradients: + return + self._accum_steps.assign(0) + for gradient in self._gradients: + if gradient is not None: + gradient.assign(tf.zeros_like(gradient)) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/processing_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/processing_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8b40b6535f1b5448c8f97fa913aab224194f7121 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/processing_utils.py @@ -0,0 +1,1782 @@ +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processing saving/loading class for common processors. +""" + +import bisect +import copy +import inspect +import json +import os +import sys +import typing +import warnings +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Optional, TypedDict, TypeVar, Union + +import numpy as np +import typing_extensions +from huggingface_hub.errors import EntryNotFoundError + +from .audio_utils import AudioInput, load_audio +from .dynamic_module_utils import custom_object_save +from .feature_extraction_utils import BatchFeature +from .image_utils import ChannelDimension, ImageInput, is_vision_available +from .utils.chat_template_utils import render_jinja_template +from .video_utils import VideoInput, VideoMetadata + + +if is_vision_available(): + from .image_utils import PILImageResampling + + +from .tokenization_utils_base import ( + PaddingStrategy, + PreTokenizedInput, + PreTrainedTokenizerBase, + TextInput, + TruncationStrategy, +) +from .utils import ( + AUDIO_TOKENIZER_NAME, + CHAT_TEMPLATE_DIR, + CHAT_TEMPLATE_FILE, + LEGACY_PROCESSOR_CHAT_TEMPLATE_FILE, + PROCESSOR_NAME, + PushToHubMixin, + TensorType, + cached_file, + copy_func, + direct_transformers_import, + download_url, + is_offline_mode, + is_remote_url, + is_torch_available, + list_repo_templates, + logging, +) +from .utils.deprecation import deprecate_kwarg + + +if is_torch_available(): + from .modeling_utils import PreTrainedAudioTokenizerBase + + +logger = logging.get_logger(__name__) + +# type hinting: specifying the type of processor class that inherits from ProcessorMixin +SpecificProcessorType = TypeVar("SpecificProcessorType", bound="ProcessorMixin") + +# Dynamically import the Transformers module to grab the attribute classes of the processor from their names. +transformers_module = direct_transformers_import(Path(__file__).parent) + + +AUTO_TO_BASE_CLASS_MAPPING = { + "AutoTokenizer": "PreTrainedTokenizerBase", + "AutoFeatureExtractor": "FeatureExtractionMixin", + "AutoImageProcessor": "ImageProcessingMixin", + "AutoVideoProcessor": "BaseVideoProcessor", +} + +if sys.version_info >= (3, 11): + Unpack = typing.Unpack +else: + Unpack = typing_extensions.Unpack + + +class TextKwargs(TypedDict, total=False): + """ + Keyword arguments for text processing. For extended documentation, check out tokenization_utils_base methods and + docstrings associated. + + Attributes: + add_special_tokens (`bool`, *optional*) + Whether or not to add special tokens when encoding the sequences. + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*) + Activates and controls padding. + truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*): + Activates and controls truncation. + max_length (`int`, *optional*): + Controls the maximum length to use by one of the truncation/padding parameters. + stride (`int`, *optional*): + If set, the overflowing tokens will contain some tokens from the end of the truncated sequence. + is_split_into_words (`bool`, *optional*): + Whether or not the input is already pre-tokenized. + pad_to_multiple_of (`int`, *optional*): + If set, will pad the sequence to a multiple of the provided value. + return_token_type_ids (`bool`, *optional*): + Whether to return token type IDs. + return_attention_mask (`bool`, *optional*): + Whether to return the attention mask. + return_overflowing_tokens (`bool`, *optional*): + Whether or not to return overflowing token sequences. + return_special_tokens_mask (`bool`, *optional*): + Whether or not to return special tokens mask information. + return_offsets_mapping (`bool`, *optional*): + Whether or not to return `(char_start, char_end)` for each token. + return_length (`bool`, *optional*): + Whether or not to return the lengths of the encoded inputs. + verbose (`bool`, *optional*): + Whether or not to print more information and warnings. + padding_side (`str`, *optional*): + The side on which padding will be applied. + return_mm_token_type_ids (`bool`, *optional*): + Whether to return multimodal token type ids indicating mm placeholder token positions. + """ + + text_pair: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]] + text_target: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] + text_pair_target: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]] + add_special_tokens: Optional[bool] + padding: Union[bool, str, PaddingStrategy] + truncation: Union[bool, str, TruncationStrategy] + max_length: Optional[int] + stride: Optional[int] + is_split_into_words: Optional[bool] + pad_to_multiple_of: Optional[int] + return_token_type_ids: Optional[bool] + return_attention_mask: Optional[bool] + return_overflowing_tokens: Optional[bool] + return_special_tokens_mask: Optional[bool] + return_offsets_mapping: Optional[bool] + return_length: Optional[bool] + verbose: Optional[bool] + padding_side: Optional[str] + return_mm_token_type_ids: Optional[bool] + + +class ImagesKwargs(TypedDict, total=False): + """ + Keyword arguments for image processing. For extended documentation, check the appropriate ImageProcessor + class methods and docstrings. + + Attributes: + do_resize (`bool`, *optional*): + Whether to resize the image. + size (`dict[str, int]`, *optional*): + Resize the shorter side of the input to `size["shortest_edge"]`. + crop_size (`dict[str, int]`, *optional*): + Desired output size when applying center-cropping. + resample (`PILImageResampling`, *optional*): + Resampling filter to use if resizing the image. + do_rescale (`bool`, *optional*): + Whether to rescale the image by the specified scale `rescale_factor`. + rescale_factor (`int` or `float`, *optional*): + Scale factor to use if rescaling the image. + do_normalize (`bool`, *optional*): + Whether to normalize the image. + image_mean (`float` or `list[float]`, *optional*): + Mean to use if normalizing the image. + image_std (`float` or `list[float]`, *optional*): + Standard deviation to use if normalizing the image. + do_pad (`bool`, *optional*): + Whether to pad the image to the `(max_height, max_width)` of the images in the batch. + pad_size (`dict[str, int]`, *optional*): + The size `{"height": int, "width" int}` to pad the images to. + do_center_crop (`bool`, *optional*): + Whether to center crop the image. + data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the output image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. + device (`str`, *optional*): + The device to use for processing (e.g. "cpu", "cuda"), only relevant for fast image processing. + """ + + do_resize: Optional[bool] + size: Optional[dict[str, int]] + crop_size: Optional[dict[str, int]] + resample: Optional[Union["PILImageResampling", int]] + do_rescale: Optional[bool] + rescale_factor: Optional[float] + do_normalize: Optional[bool] + image_mean: Optional[Union[float, list[float]]] + image_std: Optional[Union[float, list[float]]] + do_pad: Optional[bool] + pad_size: Optional[dict[str, int]] + do_center_crop: Optional[bool] + data_format: Optional[ChannelDimension] + input_data_format: Optional[Union[str, ChannelDimension]] + device: Optional[str] + + +class VideosKwargs(TypedDict, total=False): + """ + Keyword arguments for video processing. + + Attributes: + do_convert_rgb (`bool`): + Whether to convert the video to RGB format. + do_resize (`bool`): + Whether to resize the video. + size (`dict[str, int]`, *optional*): + Resize the shorter side of the input to `size["shortest_edge"]`. + default_to_square (`bool`, *optional*, defaults to `self.default_to_square`): + Whether to default to a square when resizing, if size is an int. + resample (`PILImageResampling`, *optional*): + Resampling filter to use if resizing the video. + do_rescale (`bool`, *optional*): + Whether to rescale the video by the specified scale `rescale_factor`. + rescale_factor (`int` or `float`, *optional*): + Scale factor to use if rescaling the video. + do_normalize (`bool`, *optional*): + Whether to normalize the video. + image_mean (`float` or `list[float]`, *optional*): + Mean to use if normalizing the video. + image_std (`float` or `list[float]`, *optional*): + Standard deviation to use if normalizing the video. + do_center_crop (`bool`, *optional*): + Whether to center crop the video. + do_sample_frames (`bool`, *optional*): + Whether to sample frames from the video before processing or to process the whole video. + video_metadata (`Union[VideoMetadata, dict]`, *optional*): + Metadata of the video containing information about total duration, fps and total number of frames. + num_frames (`int`, *optional*): + Maximum number of frames to sample when `do_sample_frames=True`. + fps (`int` or `float`, *optional*): + Target frames to sample per second when `do_sample_frames=True`. + crop_size (`dict[str, int]`, *optional*): + Desired output size when applying center-cropping. + data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the output video. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input video. + return_metadata (`ChannelDimension` or `str`, *optional*): + Whether to return video metadata or not. + """ + + do_convert_rgb: Optional[bool] + do_resize: Optional[bool] + size: Optional[dict[str, int]] + default_to_square: Optional[bool] + resample: Optional["PILImageResampling"] + do_rescale: Optional[bool] + rescale_factor: Optional[float] + do_normalize: Optional[bool] + image_mean: Optional[Union[float, list[float]]] + image_std: Optional[Union[float, list[float]]] + do_center_crop: Optional[bool] + crop_size: Optional[dict[str, int]] + data_format: Optional[ChannelDimension] + input_data_format: Optional[Union[str, ChannelDimension]] + device: Optional[str] + do_sample_frames: Optional[bool] + video_metadata: Optional[Union[VideoMetadata, dict]] + fps: Optional[Union[int, float]] + num_frames: Optional[int] + return_metadata: Optional[bool] + + +class AudioKwargs(TypedDict, total=False): + """ + Keyword arguments for audio processing. + + Attributes: + sampling_rate (`int`, *optional*): + The sampling rate at which the `raw_speech` input was sampled. + raw_speech (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`): + The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float + values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not + stereo, i.e. single float per timestep. + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*): + Select a strategy to pad the returned sequences (according to the model's padding side and padding + index) among: + + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` + max_length (`int`, *optional*): + Maximum length of the returned list and optionally padding length (see above). + truncation (`bool`, *optional*): + Activates truncation to cut input sequences longer than *max_length* to *max_length*. + pad_to_multiple_of (`int`, *optional*): + If set, will pad the sequence to a multiple of the provided value. + return_attention_mask (`bool`, *optional*): + Whether or not [`~ASTFeatureExtractor.__call__`] should return `attention_mask`. + """ + + sampling_rate: Optional[int] + raw_speech: Optional[Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]]] + padding: Optional[Union[bool, str, PaddingStrategy]] + max_length: Optional[int] + truncation: Optional[bool] + pad_to_multiple_of: Optional[int] + return_attention_mask: Optional[bool] + + +class CommonKwargs(TypedDict, total=False): + return_tensors: Optional[Union[str, TensorType]] + + +class ProcessingKwargs(TypedDict, total=False): + """ + Base class for kwargs passing to processors. + In case a model has specific kwargs that are not present in the base class or default values for existing keys, + it should have its own `ModelProcessorKwargs` class that inherits from `ProcessingKwargs` to provide: + 1) Additional typed keys and that this model requires to process inputs. + 2) Default values for existing keys under a `_defaults` attribute. + New keys have to be defined as follows to ensure type hinting is done correctly. + + ```python + # adding a new image kwarg for this model + class ModelImagesKwargs(ImagesKwargs, total=False): + new_image_kwarg: Optional[bool] + + class ModelProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: ModelImagesKwargs + _defaults = { + "images_kwargs: { + "new_image_kwarg": False, + } + "text_kwargs": { + "padding": "max_length", + }, + } + + ``` + + For Python 3.8 compatibility, when inheriting from this class and overriding one of the kwargs, + you need to manually update the __annotations__ dictionary. This can be done as follows: + + ```python + class CustomProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: CustomImagesKwargs + + CustomProcessorKwargs.__annotations__["images_kwargs"] = CustomImagesKwargs # python 3.8 compatibility + ```python + + """ + + _defaults = {} + + common_kwargs: CommonKwargs = { + **CommonKwargs.__annotations__, + } + text_kwargs: TextKwargs = { + **TextKwargs.__annotations__, + } + images_kwargs: ImagesKwargs = { + **ImagesKwargs.__annotations__, + } + videos_kwargs: VideosKwargs = { + **VideosKwargs.__annotations__, + } + audio_kwargs: AudioKwargs = { + **AudioKwargs.__annotations__, + } + + +class TokenizerChatTemplateKwargs(TypedDict, total=False): + """ + Keyword arguments for tokenizer's `apply_chat_template`, when it is called from within a processor. + + tools (`list[Dict]`, *optional*): + A list of tools (callable functions) that will be accessible to the model. If the template does not + support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema, + giving the name, description and argument types for the tool. See our + [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use) + for more information. + documents (`list[dict[str, str]]`, *optional*): + A list of dicts representing documents that will be accessible to the model if it is performing RAG + (retrieval-augmented generation). If the template does not support RAG, this argument will have no + effect. We recommend that each document should be a dict containing "title" and "text" keys. Please + see the RAG section of the [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#arguments-for-RAG) + for examples of passing documents with chat templates. + add_generation_prompt (bool, *optional*): + If this is set, a prompt with the token(s) that indicate + the start of an assistant message will be appended to the formatted output. This is useful when you want to generate a response from the model. + Note that this argument will be passed to the chat template, and so it must be supported in the + template for this argument to have any effect. + continue_final_message (bool, *optional*): + If this is set, the chat will be formatted so that the final + message in the chat is open-ended, without any EOS tokens. The model will continue this message + rather than starting a new one. This allows you to "prefill" part of + the model's response for it. Cannot be used at the same time as `add_generation_prompt`. + return_assistant_tokens_mask (`bool`, defaults to `False`): + Whether to return a mask of the assistant generated tokens. For tokens generated by the assistant, + the mask will contain 1. For user and system tokens, the mask will contain 0. + This functionality is only available for chat templates that support it via the `{% generation %}` keyword. + """ + + tools: Optional[list[dict]] = None + documents: Optional[list[dict[str, str]]] = None + add_generation_prompt: Optional[bool] = False + continue_final_message: Optional[bool] = False + return_assistant_tokens_mask: Optional[bool] = False + + +class ChatTemplateLoadKwargs(TypedDict, total=False): + """ + Keyword arguments used to load multimodal data in processor chat templates. + + num_frames (`int`, *optional*): + Number of frames to sample uniformly. If not passed, the whole video is loaded. + load_audio_from_video (`bool`, *optional*): + Whether to use the audio track of input video. If `True` the audio track will be loaded and passed to the + processor. This flag has no effect if the model doesn't support audio modality. + """ + + sampling_rate: Optional[int] = 16_000 + load_audio_from_video: Optional[bool] = False + + +class ProcessorChatTemplateKwargs(ChatTemplateLoadKwargs, TokenizerChatTemplateKwargs, total=False): + """ + Keyword arguments for processor's `apply_chat_template`. + + tokenize (`bool`, *optional*, defaults to `False`): + Whether to tokenize the output or not. + return_dict (`bool`, defaults to `False`): + Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`. + """ + + tokenize: Optional[bool] = False + return_dict: Optional[bool] = False + + +class AllKwargsForChatTemplate(TypedDict, total=False): + processor_kwargs: ProcessingKwargs + mm_load_kwargs: ChatTemplateLoadKwargs + template_kwargs: ProcessorChatTemplateKwargs + + +@dataclass +class MultiModalData: + """ + Dataclass that holds extra useful data for processing + multimodal data. Processors currently cannot return keys, + unless it is used in model's forward. Thus we have helper + methods that calculate and return useful data from processing + input multimodals (images/videos). + Note that this dataclass is aimed to be used only in vLLM + and we might change its API in the future. + """ + + num_image_tokens: Optional[list[int]] = None + num_video_tokens: Optional[list[int]] = None + num_audio_tokens: Optional[list[int]] = None + num_image_patches: Optional[list[int]] = None + + def __contains__(self, key): + return hasattr(self, key) and getattr(self, key) is not None + + def __getitem__(self, key): + if hasattr(self, key): + return getattr(self, key) + raise AttributeError(f"{self.__class__.__name__} has no attribute {key}") + + +class ProcessorMixin(PushToHubMixin): + """ + This is a mixin used to provide saving/loading functionality for all processor classes. + """ + + attributes = ["feature_extractor", "tokenizer"] + optional_attributes = ["chat_template", "audio_tokenizer"] + optional_call_args: list[str] = [] + # Names need to be attr_class for attr in attributes + feature_extractor_class = None + tokenizer_class = None + _auto_class = None + valid_processor_kwargs = ProcessingKwargs + + # args have to match the attributes class attribute + def __init__(self, *args, **kwargs): + # First, extract optional attributes from kwargs if present + # Optional attributes can never be positional arguments + for optional_attribute in self.optional_attributes: + optional_attribute_value = kwargs.pop(optional_attribute, None) + setattr(self, optional_attribute, optional_attribute_value) + + # Check audio tokenizer for its class but do not treat it as attr to avoid saving weights + if optional_attribute == "audio_tokenizer" and optional_attribute_value is not None: + proper_class = self.check_argument_for_proper_class(optional_attribute, optional_attribute_value) + + if not (is_torch_available() and isinstance(optional_attribute_value, PreTrainedAudioTokenizerBase)): + raise ValueError( + f"Tried to use `{proper_class}` for audio tokenization. However, this class is not" + " registered for audio tokenization." + ) + + # Sanitize args and kwargs + for key in kwargs: + if key not in self.attributes: + raise TypeError(f"Unexpected keyword argument {key}.") + for arg, attribute_name in zip(args, self.attributes): + if attribute_name in kwargs: + raise TypeError(f"Got multiple values for argument {attribute_name}.") + else: + kwargs[attribute_name] = arg + + if len(kwargs) != len(self.attributes): + raise ValueError( + f"This processor requires {len(self.attributes)} arguments: {', '.join(self.attributes)}. Got " + f"{len(args)} arguments instead." + ) + + # Check each arg is of the proper class (this will also catch a user initializing in the wrong order) + for attribute_name, arg in kwargs.items(): + self.check_argument_for_proper_class(attribute_name, arg) + setattr(self, attribute_name, arg) + + def __call__( + self, + images: Optional[ImageInput] = None, + text: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]] = None, + videos: Optional[VideoInput] = None, + audio: Optional[AudioInput] = None, + **kwargs: Unpack[ProcessingKwargs], + ): + """ + Main method to prepare for model inputs. This method forwards the each modality argument to its own processor + along with `kwargs`. Please refer to the docstring of the each processor attributes for more information. + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + text (`TextInput`, `PreTokenizedInput`, `list[TextInput]`, `list[PreTokenizedInput]`, *optional*): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`): + The video or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch + tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported. + audio (`np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, `list[torch.Tensor]`): + The audio or batch of audio to be prepared. Each audio can be a NumPy array or PyTorch + tensor. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] object with processed inputs in a dict format. + """ + if images is None and text is None and videos is None and audio is None: + raise ValueError(f"You need to provide at least one input to call {self.__class__.__name__}") + + kwargs = self._merge_kwargs( + self.valid_processor_kwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs if hasattr(self, "tokenizer") else {}, + **kwargs, + ) + + attribute_to_kwargs = { + "tokenizer": (text, "text_kwargs"), + "image_processor": (images, "images_kwargs"), + "video_processor": (videos, "videos_kwargs"), + "feature_extractor": (audio, "audio_kwargs"), + } + outputs = {} + for attribute_name in self.attributes: + attribute = getattr(self, attribute_name, None) + input_data, input_kwargs = attribute_to_kwargs[attribute_name] + if input_data is not None and attribute is not None: + attribute_output = attribute(input_data, **kwargs[input_kwargs]) + outputs.update(attribute_output) + + return BatchFeature(outputs) + + def check_argument_for_proper_class(self, argument_name, argument): + """ + Checks the passed argument's class against the expected transformers class. In case of an unexpected + mismatch between expected and actual class, an error is raise. Otherwise, the proper retrieved class + is returned. + """ + class_name = getattr(self, f"{argument_name}_class") + # Nothing is ever going to be an instance of "AutoXxx", in that case we check the base class. + class_name = AUTO_TO_BASE_CLASS_MAPPING.get(class_name, class_name) + if isinstance(class_name, tuple): + proper_class = tuple(self.get_possibly_dynamic_module(n) for n in class_name if n is not None) + else: + proper_class = self.get_possibly_dynamic_module(class_name) + + if not isinstance(argument, proper_class): + raise TypeError( + f"Received a {type(argument).__name__} for argument {argument_name}, but a {class_name} was expected." + ) + + return proper_class + + def to_dict(self, legacy_serialization=True) -> dict[str, Any]: + """ + Serializes this instance to a Python dictionary. + + Returns: + `dict[str, Any]`: Dictionary of all the attributes that make up this processor instance. + """ + output = copy.deepcopy(self.__dict__) + + # Get the kwargs in `__init__`. + sig = inspect.signature(self.__init__) + # Only save the attributes that are presented in the kwargs of `__init__`. + attrs_to_save = list(sig.parameters) + # extra attributes to be kept + attrs_to_save += ["auto_map"] + + if legacy_serialization: + # Don't save attributes like `tokenizer`, `image processor` etc. in processor config if `legacy=True` + attrs_to_save = [x for x in attrs_to_save if x not in self.__class__.attributes] + + if "tokenizer" in output: + del output["tokenizer"] + if "qformer_tokenizer" in output: + del output["qformer_tokenizer"] + if "protein_tokenizer" in output: + del output["protein_tokenizer"] + if "chat_template" in output: + del output["chat_template"] + + def cast_array_to_list(dictionary): + """ + Numpy arrays are not serialiazable but can be in pre-processing dicts. + This function casts arrays to list, recusring through the nested configs as well. + """ + for key, value in dictionary.items(): + if isinstance(value, np.ndarray): + dictionary[key] = value.tolist() + elif isinstance(value, dict): + dictionary[key] = cast_array_to_list(value) + return dictionary + + # Serialize attributes as a dict + output = { + k: v.to_dict() if isinstance(v, PushToHubMixin) else v + for k, v in output.items() + if ( + k in attrs_to_save # keep all attributes that have to be serialized + and v.__class__.__name__ != "BeamSearchDecoderCTC" # remove attributes with that are objects + and ( + (legacy_serialization and not isinstance(v, PushToHubMixin)) or not legacy_serialization + ) # remove `PushToHubMixin` objects + ) + } + output = cast_array_to_list(output) + + # Special case, add `audio_tokenizer` dict which points to model weights and path + if not legacy_serialization and "audio_tokenizer" in output: + audio_tokenizer_dict = { + "audio_tokenizer_class": self.audio_tokenizer.__class__.__name__, + "audio_tokenizer_name_or_path": self.audio_tokenizer.name_or_path, + } + # Update or overwrite, what do audio tokenizers expect when loading? + output["audio_tokenizer"] = audio_tokenizer_dict + + output["processor_class"] = self.__class__.__name__ + + return output + + def to_json_string(self, legacy_serialization=True) -> str: + """ + Serializes this instance to a JSON string. + + Returns: + `str`: String containing all the attributes that make up this feature_extractor instance in JSON format. + """ + dictionary = self.to_dict(legacy_serialization=legacy_serialization) + + return json.dumps(dictionary, indent=2, sort_keys=True) + "\n" + + def to_json_file(self, json_file_path: Union[str, os.PathLike], legacy_serialization=True): + """ + Save this instance to a JSON file. + + Args: + json_file_path (`str` or `os.PathLike`): + Path to the JSON file in which this processor instance's parameters will be saved. + """ + with open(json_file_path, "w", encoding="utf-8") as writer: + writer.write(self.to_json_string(legacy_serialization=legacy_serialization)) + + def __repr__(self): + attributes_repr = [f"- {name}: {repr(getattr(self, name))}" for name in self.attributes] + attributes_repr = "\n".join(attributes_repr) + return f"{self.__class__.__name__}:\n{attributes_repr}\n\n{self.to_json_string()}" + + def save_pretrained(self, save_directory, push_to_hub: bool = False, legacy_serialization: bool = True, **kwargs): + """ + Saves the attributes of this processor (feature extractor, tokenizer...) in the specified directory so that it + can be reloaded using the [`~ProcessorMixin.from_pretrained`] method. + + + + This class method is simply calling [`~feature_extraction_utils.FeatureExtractionMixin.save_pretrained`] and + [`~tokenization_utils_base.PreTrainedTokenizerBase.save_pretrained`]. Please refer to the docstrings of the + methods above for more information. + + + + Args: + save_directory (`str` or `os.PathLike`): + Directory where the feature extractor JSON file and the tokenizer files will be saved (directory will + be created if it does not exist). + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the + repository you want to push to with `repo_id` (will default to the name of `save_directory` in your + namespace). + legacy_serialization (`bool`, *optional*, defaults to `True`): + Whether or not to save processor attributes in separate config files (legacy) or in processor's config + file as a nested dict. Saving all attributes in a single dict will become the default in future versions. + Set to `legacy_serialization=True` until then. + kwargs (`dict[str, Any]`, *optional*): + Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. + """ + use_auth_token = kwargs.pop("use_auth_token", None) + + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if kwargs.get("token") is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + kwargs["token"] = use_auth_token + + os.makedirs(save_directory, exist_ok=True) + + if push_to_hub: + commit_message = kwargs.pop("commit_message", None) + repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) + repo_id = self._create_repo(repo_id, **kwargs) + files_timestamps = self._get_files_timestamps(save_directory) + # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be + # loaded from the Hub. + if self._auto_class is not None: + attrs = [getattr(self, attribute_name) for attribute_name in self.attributes] + configs = [(a.init_kwargs if isinstance(a, PreTrainedTokenizerBase) else a) for a in attrs] + configs.append(self) + custom_object_save(self, save_directory, config=configs) + + save_jinja_files = kwargs.get("save_jinja_files", True) + + for attribute_name in self.attributes: + # Save the tokenizer in its own vocab file. The other attributes are saved as part of `processor_config.json` + if attribute_name == "tokenizer": + attribute = getattr(self, attribute_name) + if hasattr(attribute, "_set_processor_class"): + attribute._set_processor_class(self.__class__.__name__) + + # Propagate save_jinja_files to tokenizer to ensure we don't get conflicts + attribute.save_pretrained(save_directory, save_jinja_files=save_jinja_files) + elif legacy_serialization: + attribute = getattr(self, attribute_name) + # Include the processor class in attribute config so this processor can then be reloaded with `AutoProcessor` API. + if hasattr(attribute, "_set_processor_class"): + attribute._set_processor_class(self.__class__.__name__) + attribute.save_pretrained(save_directory) + + if self._auto_class is not None: + # We added an attribute to the init_kwargs of the tokenizers, which needs to be cleaned up. + for attribute_name in self.attributes: + attribute = getattr(self, attribute_name) + if isinstance(attribute, PreTrainedTokenizerBase): + del attribute.init_kwargs["auto_map"] + + # If we save using the predefined names, we can load using `from_pretrained` + # plus we save chat_template in its own file + output_processor_file = os.path.join(save_directory, PROCESSOR_NAME) + output_chat_template_file_jinja = os.path.join(save_directory, CHAT_TEMPLATE_FILE) + output_chat_template_file_legacy = os.path.join( + save_directory, LEGACY_PROCESSOR_CHAT_TEMPLATE_FILE + ) # Legacy filename + chat_template_dir = os.path.join(save_directory, CHAT_TEMPLATE_DIR) + + # Save `chat_template` in its own file. We can't get it from `processor_dict` as we popped it in `to_dict` + # to avoid serializing chat template in json config file. So let's get it from `self` directly + if self.chat_template is not None: + save_jinja_files = kwargs.get("save_jinja_files", True) + is_single_template = isinstance(self.chat_template, str) + if save_jinja_files and is_single_template: + # New format for single templates is to save them as chat_template.jinja + with open(output_chat_template_file_jinja, "w", encoding="utf-8") as f: + f.write(self.chat_template) + logger.info(f"chat template saved in {output_chat_template_file_jinja}") + elif save_jinja_files and not is_single_template: + # New format for multiple templates is to save the default as chat_template.jinja + # and the other templates in the chat_templates/ directory + for template_name, template in self.chat_template.items(): + if template_name == "default": + with open(output_chat_template_file_jinja, "w", encoding="utf-8") as f: + f.write(self.chat_template["default"]) + logger.info(f"chat template saved in {output_chat_template_file_jinja}") + else: + os.makedirs(chat_template_dir, exist_ok=True) + template_filepath = os.path.join(chat_template_dir, f"{template_name}.jinja") + with open(template_filepath, "w", encoding="utf-8") as f: + f.write(template) + logger.info(f"chat template saved in {template_filepath}") + elif is_single_template: + # Legacy format for single templates: Put them in chat_template.json + chat_template_json_string = ( + json.dumps({"chat_template": self.chat_template}, indent=2, sort_keys=True) + "\n" + ) + with open(output_chat_template_file_legacy, "w", encoding="utf-8") as writer: + writer.write(chat_template_json_string) + logger.info(f"chat template saved in {output_chat_template_file_legacy}") + elif self.chat_template is not None: + # At this point we have multiple templates in the legacy format, which is not supported + # chat template dicts are saved to chat_template.json as lists of dicts with fixed key names. + raise ValueError( + "Multiple chat templates are not supported in the legacy format. Please save them as " + "separate files using the `save_jinja_files` argument." + ) + + if legacy_serialization: + output_audio_tokenizer_file = os.path.join(save_directory, AUDIO_TOKENIZER_NAME) + processor_dict = self.to_dict() + + # For now, let's not save to `processor_config.json` if the processor doesn't have extra attributes and + # `auto_map` is not specified. + if set(processor_dict.keys()) != {"processor_class"}: + self.to_json_file(output_processor_file) + logger.info(f"processor saved in {output_processor_file}") + + if set(processor_dict.keys()) == {"processor_class"}: + return_files = [] + else: + return_files = [output_processor_file] + + if self.audio_tokenizer is not None: + audio_tokenizer_class = self.audio_tokenizer.__class__.__name__ + audio_tokenizer_name_or_path = self.audio_tokenizer.name_or_path + audio_tokenizer_dict = { + "audio_tokenizer_class": audio_tokenizer_class, + "audio_tokenizer_name_or_path": audio_tokenizer_name_or_path, + } + audio_tokenizer_json = json.dumps(audio_tokenizer_dict, indent=2, sort_keys=True) + "\n" + with open(output_audio_tokenizer_file, "w", encoding="utf-8") as writer: + writer.write(audio_tokenizer_json) + + # Create a unified `preprocessor_config.json` and save all attributes as a composite config, except for tokenizers + # NOTE: this will become the default way to save all processor attrbiutes in future versions. Toggled off for now to give + # us time for smoother transition + else: + self.to_json_file(output_processor_file, legacy_serialization=False) + logger.info(f"processor saved in {output_processor_file}") + return_files = [output_processor_file] + + if push_to_hub: + self._upload_modified_files( + save_directory, + repo_id, + files_timestamps, + commit_message=commit_message, + token=kwargs.get("token"), + ) + + return return_files + + @classmethod + def get_processor_dict( + cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs + ) -> tuple[dict[str, Any], dict[str, Any]]: + """ + From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a + processor of type [`~processing_utils.ProcessingMixin`] using `from_args_and_dict`. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`): + The identifier of the pre-trained checkpoint from which we want the dictionary of parameters. + subfolder (`str`, *optional*, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can + specify the folder name here. + + Returns: + `tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the processor object. + """ + # holding a copy for optionally loading the audio tokenizer (if available) + audio_tokenizer_kwargs = copy.deepcopy(kwargs) + + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", None) + proxies = kwargs.pop("proxies", None) + token = kwargs.pop("token", None) + local_files_only = kwargs.pop("local_files_only", False) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", "") + + from_pipeline = kwargs.pop("_from_pipeline", None) + from_auto_class = kwargs.pop("_from_auto", False) + + user_agent = {"file_type": "processor", "from_auto_class": from_auto_class} + if from_pipeline is not None: + user_agent["using_pipeline"] = from_pipeline + + if is_offline_mode() and not local_files_only: + logger.info("Offline mode: forcing local_files_only=True") + local_files_only = True + + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + is_local = os.path.isdir(pretrained_model_name_or_path) + if os.path.isdir(pretrained_model_name_or_path): + processor_file = os.path.join(pretrained_model_name_or_path, PROCESSOR_NAME) + + additional_chat_template_files = {} + resolved_additional_chat_template_files = {} + if os.path.isfile(pretrained_model_name_or_path): + resolved_processor_file = pretrained_model_name_or_path + # can't load chat-template and audio tokenizer when given a file as pretrained_model_name_or_path + resolved_chat_template_file = None + resolved_raw_chat_template_file = None + resolved_audio_tokenizer_file = None + is_local = True + elif is_remote_url(pretrained_model_name_or_path): + processor_file = pretrained_model_name_or_path + resolved_processor_file = download_url(pretrained_model_name_or_path) + # can't load chat-template and audio tokenizer when given a file url as pretrained_model_name_or_path + resolved_chat_template_file = None + resolved_raw_chat_template_file = None + resolved_audio_tokenizer_file = None + else: + if is_local: + template_dir = Path(pretrained_model_name_or_path, CHAT_TEMPLATE_DIR) + if template_dir.is_dir(): + for template_file in template_dir.glob("*.jinja"): + template_name = template_file.stem + additional_chat_template_files[template_name] = f"{CHAT_TEMPLATE_DIR}/{template_file.name}" + else: + try: + for template in list_repo_templates( + pretrained_model_name_or_path, + local_files_only=local_files_only, + revision=revision, + cache_dir=cache_dir, + token=token, + ): + additional_chat_template_files[template] = f"{CHAT_TEMPLATE_DIR}/{template}.jinja" + except EntryNotFoundError: + pass # No template dir means no template files + processor_file = PROCESSOR_NAME + + try: + # Load from local folder or from cache or download from model Hub and cache + resolved_processor_file = cached_file( + pretrained_model_name_or_path, + processor_file, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + revision=revision, + subfolder=subfolder, + _raise_exceptions_for_missing_entries=False, + ) + + # chat_template.json is a legacy file used by the processor class + # a raw chat_template.jinja is preferred in future + resolved_chat_template_file = cached_file( + pretrained_model_name_or_path, + LEGACY_PROCESSOR_CHAT_TEMPLATE_FILE, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + revision=revision, + subfolder=subfolder, + _raise_exceptions_for_missing_entries=False, + ) + + resolved_raw_chat_template_file = cached_file( + pretrained_model_name_or_path, + CHAT_TEMPLATE_FILE, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + revision=revision, + subfolder=subfolder, + _raise_exceptions_for_missing_entries=False, + ) + + resolved_additional_chat_template_files = { + template_name: cached_file( + pretrained_model_name_or_path, + template_file, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + revision=revision, + subfolder=subfolder, + _raise_exceptions_for_missing_entries=False, + ) + for template_name, template_file in additional_chat_template_files.items() + } + + resolved_audio_tokenizer_file = cached_file( + pretrained_model_name_or_path, + AUDIO_TOKENIZER_NAME, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + revision=revision, + subfolder=subfolder, + _raise_exceptions_for_missing_entries=False, + ) + except OSError: + # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to + # the original exception. + raise + except Exception: + # For any other exception, we throw a generic error. + raise OSError( + f"Can't load processor for '{pretrained_model_name_or_path}'. If you were trying to load" + " it from 'https://huggingface.co/models', make sure you don't have a local directory with the" + f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a" + f" directory containing a {PROCESSOR_NAME} file" + ) + + # Add chat template as kwarg before returning because most models don't have processor config + if resolved_chat_template_file is not None: + # This is the legacy path + with open(resolved_chat_template_file, encoding="utf-8") as reader: + chat_template_json = json.loads(reader.read()) + chat_templates = {"default": chat_template_json["chat_template"]} + if resolved_additional_chat_template_files: + raise ValueError( + "Cannot load chat template due to conflicting files - this checkpoint combines " + "a legacy chat_template.json file with separate template files, which is not " + "supported. To resolve this error, replace the legacy chat_template.json file " + "with a modern chat_template.jinja file." + ) + else: + chat_templates = { + template_name: open(template_file, "r", encoding="utf-8").read() + for template_name, template_file in resolved_additional_chat_template_files.items() + } + if resolved_raw_chat_template_file is not None: + with open(resolved_raw_chat_template_file, "r", encoding="utf-8") as reader: + chat_templates["default"] = reader.read() + if isinstance(chat_templates, dict) and "default" in chat_templates and len(chat_templates) == 1: + chat_templates = chat_templates["default"] # Flatten when we just have a single template/file + + if chat_templates: + kwargs["chat_template"] = chat_templates + + # Existing processors on the Hub created before #27761 being merged don't have `processor_config.json` (if not + # updated afterward), and we need to keep `from_pretrained` work. So here it fallbacks to the empty dict. + # (`cached_file` called using `_raise_exceptions_for_missing_entries=False` to avoid exception) + # However, for models added in the future, we won't get the expected error if this file is missing. + if resolved_processor_file is None: + # In any case we need to pass `chat_template` if it is available + processor_dict = {} + else: + try: + # Load processor dict + with open(resolved_processor_file, encoding="utf-8") as reader: + text = reader.read() + processor_dict = json.loads(text) + + except json.JSONDecodeError: + raise OSError( + f"It looks like the config file at '{resolved_processor_file}' is not a valid JSON file." + ) + + if is_local: + logger.info(f"loading configuration file {resolved_processor_file}") + else: + logger.info(f"loading configuration file {processor_file} from cache at {resolved_processor_file}") + + if "chat_template" in processor_dict and processor_dict["chat_template"] is not None: + logger.warning_once( + "Chat templates should be in a 'chat_template.jinja' file but found key='chat_template' " + "in the processor's config. Make sure to move your template to its own file." + ) + + if "chat_template" in kwargs: + processor_dict["chat_template"] = kwargs.pop("chat_template") + + # Audio tokenizer needs to load the model checkpoint first, because the saved + # json file contains only references to the model path and repo id + if resolved_audio_tokenizer_file is not None or "audio_tokenizer" in processor_dict: + if resolved_audio_tokenizer_file is not None: + reader = open(resolved_audio_tokenizer_file, "r", encoding="utf-8") + audio_tokenizer_dict = reader.read() + audio_tokenizer_dict = json.loads(audio_tokenizer_dict) + else: + audio_tokenizer_dict = processor_dict["audio_tokenizer"] + + audio_tokenizer_class = cls.get_possibly_dynamic_module(audio_tokenizer_dict["audio_tokenizer_class"]) + audio_tokenizer_path = audio_tokenizer_dict["audio_tokenizer_name_or_path"] + processor_dict["audio_tokenizer"] = audio_tokenizer_class.from_pretrained( + audio_tokenizer_path, **audio_tokenizer_kwargs + ) + + # Pop attributes if saved in a single processor dict, they are loaded in `_get_arguments_from_pretrained` + for attribute in cls.attributes: + processor_dict.pop(attribute, None) + + return processor_dict, kwargs + + @classmethod + def from_args_and_dict(cls, args, processor_dict: dict[str, Any], **kwargs): + """ + Instantiates a type of [`~processing_utils.ProcessingMixin`] from a Python dictionary of parameters. + + Args: + processor_dict (`dict[str, Any]`): + Dictionary that will be used to instantiate the processor object. Such a dictionary can be + retrieved from a pretrained checkpoint by leveraging the + [`~processing_utils.ProcessingMixin.to_dict`] method. + kwargs (`dict[str, Any]`): + Additional parameters from which to initialize the processor object. + + Returns: + [`~processing_utils.ProcessingMixin`]: The processor object instantiated from those + parameters. + """ + processor_dict = processor_dict.copy() + return_unused_kwargs = kwargs.pop("return_unused_kwargs", False) + + # We have to pop up some unused (but specific) kwargs and then validate that it doesn't contain unused kwargs + # If we don't pop, some specific kwargs will raise a warning + if "processor_class" in processor_dict: + del processor_dict["processor_class"] + + if "auto_map" in processor_dict: + del processor_dict["auto_map"] + + # override processor_dict with given kwargs + processor_dict.update(kwargs) + + # check if there is an overlap between args and processor_dict + accepted_args_and_kwargs = cls.__init__.__code__.co_varnames[: cls.__init__.__code__.co_argcount][1:] + + # validate both processor_dict and given kwargs + unused_kwargs, valid_kwargs = cls.validate_init_kwargs( + processor_config=processor_dict, valid_kwargs=accepted_args_and_kwargs + ) + + # update args that are already in processor_dict to avoid duplicate arguments + args_to_update = { + i: valid_kwargs.pop(arg) + for i, arg in enumerate(accepted_args_and_kwargs) + if (arg in valid_kwargs and i < len(args)) + } + args = [args_to_update.get(i, arg) for i, arg in enumerate(args)] + + # instantiate processor with used (and valid) kwargs only + processor = cls(*args, **valid_kwargs) + + logger.info(f"Processor {processor}") + if return_unused_kwargs: + return processor, unused_kwargs + else: + return processor + + def _merge_kwargs( + self, + ModelProcessorKwargs: ProcessingKwargs, + tokenizer_init_kwargs: Optional[dict] = None, + **kwargs, + ) -> dict[str, dict]: + """ + Method to merge dictionaries of kwargs cleanly separated by modality within a Processor instance. + The order of operations is as follows: + 1) kwargs passed as before have highest priority to preserve BC. + ```python + high_priority_kwargs = {"crop_size" = {"height": 222, "width": 222}, "padding" = "max_length"} + processor(..., **high_priority_kwargs) + ``` + 2) kwargs passed as modality-specific kwargs have second priority. This is the recommended API. + ```python + processor(..., text_kwargs={"padding": "max_length"}, images_kwargs={"crop_size": {"height": 222, "width": 222}}}) + ``` + 3) kwargs passed during instantiation of a modality processor have fourth priority. + ```python + tokenizer = tokenizer_class(..., {"padding": "max_length"}) + image_processor = image_processor_class(...) + processor(tokenizer, image_processor) # will pass max_length unless overridden by kwargs at call + ``` + 4) defaults kwargs specified at processor level have lowest priority. + ```python + class MyProcessingKwargs(ProcessingKwargs, CommonKwargs, TextKwargs, ImagesKwargs, total=False): + _defaults = { + "text_kwargs": { + "padding": "max_length", + "max_length": 64, + }, + } + ``` + Args: + ModelProcessorKwargs (`ProcessingKwargs`): + Typed dictionary of kwargs specifically required by the model passed. + tokenizer_init_kwargs (`Dict`, *optional*): + Dictionary of kwargs the tokenizer was instantiated with and need to take precedence over defaults. + + Returns: + output_kwargs (`Dict`): + Dictionary of per-modality kwargs to be passed to each modality-specific processor. + + """ + # Initialize dictionaries + output_kwargs = { + "text_kwargs": {}, + "images_kwargs": {}, + "audio_kwargs": {}, + "videos_kwargs": {}, + "common_kwargs": {}, + } + + default_kwargs = { + "text_kwargs": {}, + "images_kwargs": {}, + "audio_kwargs": {}, + "videos_kwargs": {}, + "common_kwargs": {}, + } + + possible_modality_keywords = {"text", "audio", "videos", "images"} + used_keys = set() + + # get defaults from set model processor kwargs if they exist + for modality in default_kwargs: + default_kwargs[modality] = ModelProcessorKwargs._defaults.get(modality, {}).copy() + # update defaults with arguments from tokenizer init + for modality_key in ModelProcessorKwargs.__annotations__[modality].__annotations__: + # init with tokenizer init kwargs if necessary + if tokenizer_init_kwargs is not None and modality_key in tokenizer_init_kwargs: + value = ( + getattr(self.tokenizer, modality_key) + if hasattr(self.tokenizer, modality_key) + else tokenizer_init_kwargs[modality_key] + ) + default_kwargs[modality][modality_key] = value + # now defaults kwargs are updated with the tokenizers defaults. + # pass defaults to output dictionary + output_kwargs.update(default_kwargs) + + # update modality kwargs with passed kwargs + non_modality_kwargs = set(kwargs) - set(output_kwargs) + for modality, output_kwarg in output_kwargs.items(): + for modality_key in ModelProcessorKwargs.__annotations__[modality].__annotations__: + # check if we received a structured kwarg dict or not to handle it correctly + if modality in kwargs: + kwarg_value = kwargs[modality].pop(modality_key, "__empty__") + # check if this key was passed as a flat kwarg. + if kwarg_value != "__empty__" and modality_key in non_modality_kwargs: + raise ValueError( + f"Keyword argument {modality_key} was passed two times:\n" + f"in a dictionary for {modality} and as a **kwarg." + ) + elif modality_key in kwargs: + # we get a modality_key instead of popping it because modality-specific processors + # can have overlapping kwargs + kwarg_value = kwargs.get(modality_key, "__empty__") + else: + kwarg_value = "__empty__" + if not isinstance(kwarg_value, str) or kwarg_value != "__empty__": + output_kwarg[modality_key] = kwarg_value + used_keys.add(modality_key) + + # Determine if kwargs is a flat dictionary or contains nested dictionaries + if any(key in default_kwargs for key in kwargs): + # kwargs is dictionary-based, and some keys match modality names + for modality, subdict in kwargs.items(): + if modality in default_kwargs: + for subkey, subvalue in subdict.items(): + if subkey not in used_keys: + output_kwargs[modality][subkey] = subvalue + used_keys.add(subkey) + else: + # kwargs is a flat dictionary + for key, kwarg in kwargs.items(): + if key not in used_keys: + if key in ModelProcessorKwargs.__annotations__["common_kwargs"].__annotations__: + output_kwargs["common_kwargs"][key] = kwarg + elif key not in possible_modality_keywords: + logger.warning_once( + f"Keyword argument `{key}` is not a valid argument for this processor and will be ignored." + ) + + # all modality-specific kwargs are updated with common kwargs + for kwarg in output_kwargs.values(): + kwarg.update(output_kwargs["common_kwargs"]) + return output_kwargs + + @classmethod + def from_pretrained( + cls: type[SpecificProcessorType], + pretrained_model_name_or_path: Union[str, os.PathLike], + cache_dir: Optional[Union[str, os.PathLike]] = None, + force_download: bool = False, + local_files_only: bool = False, + token: Optional[Union[str, bool]] = None, + revision: str = "main", + **kwargs, + ) -> SpecificProcessorType: + r""" + Instantiate a processor associated with a pretrained model. + + + + This class method is simply calling the feature extractor + [`~feature_extraction_utils.FeatureExtractionMixin.from_pretrained`], image processor + [`~image_processing_utils.ImageProcessingMixin`] and the tokenizer + [`~tokenization_utils_base.PreTrainedTokenizer.from_pretrained`] methods. Please refer to the docstrings of the + methods above for more information. + + + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`): + This can be either: + + - a string, the *model id* of a pretrained feature_extractor hosted inside a model repo on + huggingface.co. + - a path to a *directory* containing a feature extractor file saved using the + [`~SequenceFeatureExtractor.save_pretrained`] method, e.g., `./my_model_directory/`. + - a path or url to a saved feature extractor JSON *file*, e.g., + `./my_model_directory/preprocessor_config.json`. + **kwargs + Additional keyword arguments passed along to both + [`~feature_extraction_utils.FeatureExtractionMixin.from_pretrained`] and + [`~tokenization_utils_base.PreTrainedTokenizer.from_pretrained`]. + """ + kwargs["cache_dir"] = cache_dir + kwargs["force_download"] = force_download + kwargs["local_files_only"] = local_files_only + kwargs["revision"] = revision + + use_auth_token = kwargs.pop("use_auth_token", None) + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if token is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + token = use_auth_token + + if token is not None: + kwargs["token"] = token + + args = cls._get_arguments_from_pretrained(pretrained_model_name_or_path, **kwargs) + processor_dict, kwargs = cls.get_processor_dict(pretrained_model_name_or_path, **kwargs) + return cls.from_args_and_dict(args, processor_dict, **kwargs) + + @classmethod + def register_for_auto_class(cls, auto_class="AutoProcessor"): + """ + Register this class with a given auto class. This should only be used for custom feature extractors as the ones + in the library are already mapped with `AutoProcessor`. + + + + Args: + auto_class (`str` or `type`, *optional*, defaults to `"AutoProcessor"`): + The auto class to register this new feature extractor with. + """ + if not isinstance(auto_class, str): + auto_class = auto_class.__name__ + + import transformers.models.auto as auto_module + + if not hasattr(auto_module, auto_class): + raise ValueError(f"{auto_class} is not a valid auto class.") + + cls._auto_class = auto_class + + @classmethod + def _get_arguments_from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + """ + Identify and instantiate the subcomponents of Processor classes, like image processors and + tokenizers. This method uses the Processor attributes like `tokenizer_class` to figure out what class those + subcomponents should be. Note that any subcomponents must either be library classes that are accessible in + the `transformers` root, or they must be custom code that has been registered with the relevant autoclass, + via methods like `AutoTokenizer.register()`. If neither of these conditions are fulfilled, this method + will be unable to find the relevant subcomponent class and will raise an error. + """ + args = [] + for attribute_name in cls.attributes: + class_name = getattr(cls, f"{attribute_name}_class") + if isinstance(class_name, tuple): + classes = tuple(cls.get_possibly_dynamic_module(n) if n is not None else None for n in class_name) + if attribute_name == "image_processor": + # TODO: @yoni, change logic in v4.52 (when use_fast set to True by default) + use_fast = kwargs.get("use_fast") + if use_fast is None: + logger.warning_once( + "Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. " + "`use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. " + "This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`." + ) + else: + use_fast = kwargs.get("use_fast", True) + if use_fast and classes[1] is not None: + attribute_class = classes[1] + else: + attribute_class = classes[0] + else: + attribute_class = cls.get_possibly_dynamic_module(class_name) + + args.append(attribute_class.from_pretrained(pretrained_model_name_or_path, **kwargs)) + + return args + + @staticmethod + def get_possibly_dynamic_module(module_name): + if hasattr(transformers_module, module_name): + return getattr(transformers_module, module_name) + lookup_locations = [ + transformers_module.IMAGE_PROCESSOR_MAPPING, + transformers_module.VIDEO_PROCESSOR_MAPPING, + transformers_module.TOKENIZER_MAPPING, + transformers_module.FEATURE_EXTRACTOR_MAPPING, + transformers_module.MODEL_FOR_AUDIO_TOKENIZATION_MAPPING, + ] + for lookup_location in lookup_locations: + for custom_class in lookup_location._extra_content.values(): + if isinstance(custom_class, tuple): + for custom_subclass in custom_class: + if custom_subclass is not None and custom_subclass.__name__ == module_name: + return custom_subclass + elif custom_class is not None and custom_class.__name__ == module_name: + return custom_class + raise ValueError( + f"Could not find module {module_name} in `transformers`. If this is a custom class, " + f"it should be registered using the relevant `AutoClass.register()` function so that " + f"other functions can find it!" + ) + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + if not hasattr(self, "tokenizer"): + raise ValueError(f"Cannot batch decode text: {self.__class__.__name__} has no tokenizer.") + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + if not hasattr(self, "tokenizer"): + raise ValueError(f"Cannot decode text: {self.__class__.__name__} has no tokenizer.") + return self.tokenizer.decode(*args, **kwargs) + + @property + def model_input_names(self): + model_input_names = [] + for attribute_name in self.attributes: + attribute = getattr(self, attribute_name, None) + attr_input_names = getattr(attribute, "model_input_names") + model_input_names.extend(attr_input_names) + return model_input_names + + @staticmethod + def validate_init_kwargs(processor_config, valid_kwargs): + kwargs_from_config = set(processor_config.keys()) + valid_kwargs_set = set(valid_kwargs) + + unused_keys = kwargs_from_config - valid_kwargs_set + valid_keys = kwargs_from_config & valid_kwargs_set + + unused_kwargs = {k: processor_config[k] for k in unused_keys} if unused_keys else {} + valid_kwargs = {k: processor_config[k] for k in valid_keys} if valid_keys else {} + + return unused_kwargs, valid_kwargs + + @deprecate_kwarg("video_fps", version="4.58", new_name="fps") + @deprecate_kwarg( + "video_load_backend", + version="4.59", + additional_message=". This function will use `torchcodec` by default, or `torchvision` if `torchcodec` is not installed.", + ) + def apply_chat_template( + self, + conversation: Union[list[dict[str, str]], list[list[dict[str, str]]]], + chat_template: Optional[str] = None, + **kwargs: Unpack[AllKwargsForChatTemplate], + ) -> str: + """ + Similar to the `apply_chat_template` method on tokenizers, this method applies a Jinja template to input + conversations to turn them into a single tokenizable string. + + The input is expected to be in the following format, where each message content is a list consisting of text and + optionally image or video inputs. One can also provide an image, video, URL or local path which will be used to form + `pixel_values` when `return_dict=True`. If not provided, one will get only the formatted text, optionally tokenized text. + + conversation = [ + { + "role": "user", + "content": [ + {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}, + {"type": "text", "text": "Please describe this image in detail."}, + ], + }, + ] + + Args: + conversation (`Union[list[Dict, [str, str]], list[list[dict[str, str]]]]`): + The conversation to format. + chat_template (`Optional[str]`, *optional*): + The Jinja template to use for formatting the conversation. If not provided, the tokenizer's + chat template is used. + """ + if chat_template is None: + if isinstance(self.chat_template, dict) and "default" in self.chat_template: + chat_template = self.chat_template["default"] + elif isinstance(self.chat_template, dict): + raise ValueError( + 'The processor has multiple chat templates but none of them are named "default". You need to specify' + " which one to use by passing the `chat_template` argument. Available templates are: " + f"{', '.join(self.chat_template.keys())}" + ) + elif self.chat_template is not None: + chat_template = self.chat_template + else: + raise ValueError( + "Cannot use apply_chat_template because this processor does not have a chat template." + ) + else: + if isinstance(self.chat_template, dict) and chat_template in self.chat_template: + # It's the name of a template, not a full template string + chat_template = self.chat_template[chat_template] + else: + # It's a template string, render it directly + pass + + is_tokenizers_fast = hasattr(self, "tokenizer") and self.tokenizer.__class__.__name__.endswith("Fast") + + if kwargs.get("continue_final_message", False): + if kwargs.get("add_generation_prompt", False): + raise ValueError( + "continue_final_message and add_generation_prompt are not compatible. Use continue_final_message when you want the model to continue the final message, and add_generation_prompt when you want to add a header that will prompt it to start a new assistant message instead." + ) + if kwargs.get("return_assistant_tokens_mask", False): + raise ValueError("continue_final_message is not compatible with return_assistant_tokens_mask.") + + if kwargs.get("return_assistant_tokens_mask", False): + if not is_tokenizers_fast: + raise ValueError( + "`return_assistant_tokens_mask` is not possible with slow tokenizers. Make sure you have `tokenizers` installed. " + "If the error persists, open an issue to support a Fast tokenizer for your model." + ) + else: + kwargs["return_offsets_mapping"] = True # force offset mapping so we can infer token boundaries + + # Fill sets of kwargs that should be used by different parts of template + processed_kwargs = { + "mm_load_kwargs": {}, + "template_kwargs": {}, + } + + for kwarg_type in processed_kwargs: + for key in AllKwargsForChatTemplate.__annotations__[kwarg_type].__annotations__: + kwarg_type_defaults = AllKwargsForChatTemplate.__annotations__[kwarg_type] + default_value = getattr(kwarg_type_defaults, key, None) + value = kwargs.pop(key, default_value) + if value is not None and not isinstance(value, dict): + processed_kwargs[kwarg_type][key] = value + + # pop unused and deprecated kwarg + kwargs.pop("video_load_backend", None) + + # Pass unprocessed custom kwargs + processed_kwargs["template_kwargs"].update(kwargs) + + if isinstance(conversation, (list, tuple)) and ( + isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "content") + ): + is_batched = True + conversations = conversation + else: + is_batched = False + conversations = [conversation] + + tokenize = processed_kwargs["template_kwargs"].pop("tokenize", False) + return_dict = processed_kwargs["template_kwargs"].pop("return_dict", False) + mm_load_kwargs = processed_kwargs["mm_load_kwargs"] + + if tokenize: + batch_images, batch_videos = [], [] + batch_audios = [] + for conversation in conversations: + images, videos = [], [] + for message in conversation: + visuals = [content for content in message["content"] if content["type"] in ["image", "video"]] + audio_fnames = [ + content[key] + for content in message["content"] + for key in ["audio", "url", "path"] + if key in content and content["type"] == "audio" + ] + image_fnames = [ + vision_info[key] + for vision_info in visuals + for key in ["image", "url", "path", "base64"] + if key in vision_info and vision_info["type"] == "image" + ] + images.extend(image_fnames) + video_fnames = [ + vision_info[key] + for vision_info in visuals + for key in ["video", "url", "path"] + if key in vision_info and vision_info["type"] == "video" + ] + videos.extend(video_fnames) + + # Audio models do not accept nested list of audios (yet!) so we construct a flat input audio list + if not mm_load_kwargs["load_audio_from_video"]: + for fname in audio_fnames: + batch_audios.append(load_audio(fname, sampling_rate=mm_load_kwargs["sampling_rate"])) + else: + for fname in video_fnames: + batch_audios.append(load_audio(fname, sampling_rate=mm_load_kwargs["sampling_rate"])) + + # Currently all processors can accept nested list of batches, but not flat list of visuals + # So we'll make a batched list of images and let the processor handle it + batch_images.append(images) + batch_videos.append(videos) + + prompt, generation_indices = render_jinja_template( + conversations=conversations, + chat_template=chat_template, + **processed_kwargs["template_kwargs"], # different flags such as `return_assistant_mask` + **self.tokenizer.special_tokens_map, # tokenizer special tokens are used by some templates + ) + + if not is_batched: + prompt = prompt[0] + + if tokenize: + # Tokenizer's `apply_chat_template` never adds special tokens when tokenizing + # But processor's `apply_chat_template` didn't have an option to tokenize, so users had to format the prompt + # and pass it to the processor. Users thus never worried about special tokens relying on processor handling + # everything internally. The below line is to keep BC for that and be able to work with model that have + # special tokens in the template (consistent with tokenizers). We dont want to raise warning, it will flood command line + # without actionable solution for users + single_prompt = prompt[0] if is_batched else prompt + if self.tokenizer.bos_token is not None and single_prompt.startswith(self.tokenizer.bos_token): + kwargs["add_special_tokens"] = False + + # Always sample frames by default unless explicitly set to `False` by users. If users do not pass `num_frames`/`fps` + # sampling should not done for BC. + if "do_sample_frames" not in kwargs and ( + kwargs.get("fps") is not None or kwargs.get("num_frames") is not None + ): + kwargs["do_sample_frames"] = True + + images_exist = any((im is not None) for im_list in batch_images for im in im_list) + videos_exist = any((vid is not None) for vid_list in batch_videos for vid in vid_list) + out = self( + text=prompt, + images=batch_images if images_exist else None, + videos=batch_videos if videos_exist else None, + audio=batch_audios if batch_audios else None, + **kwargs, + ) + + if return_dict: + if processed_kwargs["template_kwargs"].get("return_assistant_tokens_mask", False): + assistant_masks = [] + offset_mapping = out.pop("offset_mapping") + input_ids = out["input_ids"] + for i in range(len(input_ids)): + current_mask = [0] * len(input_ids[i]) + offsets = offset_mapping[i] + offset_starts = [start for start, end in offsets] + for assistant_start_char, assistant_end_char in generation_indices[i]: + start_pos = bisect.bisect_left(offset_starts, assistant_start_char) + end_pos = bisect.bisect_left(offset_starts, assistant_end_char) + + if not ( + start_pos >= 0 + and offsets[start_pos][0] <= assistant_start_char < offsets[start_pos][1] + ): + # start_token is out of bounds maybe due to truncation. + continue + for token_id in range(start_pos, end_pos if end_pos else len(input_ids[i])): + current_mask[token_id] = 1 + assistant_masks.append(current_mask) + out["assistant_masks"] = assistant_masks + out.convert_to_tensors(tensor_type=kwargs.get("return_tensors")) + return out + else: + return out["input_ids"] + return prompt + + def post_process_image_text_to_text(self, generated_outputs, skip_special_tokens=True, **kwargs): + """ + Post-process the output of a vlm to decode the text. + + Args: + generated_outputs (`torch.Tensor` or `np.ndarray`): + The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)` + or `(sequence_length,)`. + skip_special_tokens (`bool`, *optional*, defaults to `True`): + Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method. + **kwargs: + Additional arguments to be passed to the tokenizer's `batch_decode method`. + + Returns: + `list[str]`: The decoded text. + """ + return self.tokenizer.batch_decode(generated_outputs, skip_special_tokens=skip_special_tokens, **kwargs) + + def _check_special_mm_tokens(self, text: list[str], text_inputs: "BatchFeature", modalities: list[str]): + """ + Checks that number of special tokens in text and processed text is same. The count can be different + if tokenized text was truncated, leading to issues in model code. + """ + for modality in modalities: + token_str = getattr(self, f"{modality}_token") + token_id = getattr(self, f"{modality}_token_id") + ids_count = [list(ids).count(token_id) for ids in text_inputs["input_ids"]] + text_count = [sample.count(token_str) for sample in text] + + if ids_count != text_count: + raise ValueError( + f"Mismatch in `{modality}` token count between text and `input_ids`. Got ids={ids_count} and text={text_count}. " + "Likely due to `truncation='max_length'`. Please disable truncation or increase `max_length`." + ) + + +ProcessorMixin.push_to_hub = copy_func(ProcessorMixin.push_to_hub) +if ProcessorMixin.push_to_hub.__doc__ is not None: + ProcessorMixin.push_to_hub.__doc__ = ProcessorMixin.push_to_hub.__doc__.format( + object="processor", object_class="AutoProcessor", object_files="processor files" + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/py.typed b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/py.typed new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/pytorch_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/pytorch_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b1f41117d4cfeb5242a83197ea4e2b04b2d8e9a7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/pytorch_utils.py @@ -0,0 +1,380 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import inspect +from functools import lru_cache, wraps +from typing import Callable + +import torch +from safetensors.torch import storage_ptr, storage_size +from torch import nn + +from .utils import ( + is_torch_greater_or_equal, + is_torch_xla_available, + is_torch_xpu_available, + is_torchdynamo_compiling, + logging, +) + + +ALL_LAYERNORM_LAYERS = [nn.LayerNorm] + +logger = logging.get_logger(__name__) + +is_torch_greater_or_equal_than_2_8 = is_torch_greater_or_equal("2.8", accept_dev=True) +is_torch_greater_or_equal_than_2_6 = is_torch_greater_or_equal("2.6", accept_dev=True) +is_torch_greater_or_equal_than_2_4 = is_torch_greater_or_equal("2.4", accept_dev=True) +is_torch_greater_or_equal_than_2_3 = is_torch_greater_or_equal("2.3", accept_dev=True) + +# For backwards compatibility (e.g. some remote codes on Hub using those variables). +is_torch_greater_or_equal_than_2_2 = is_torch_greater_or_equal("2.2", accept_dev=True) +is_torch_greater_or_equal_than_2_1 = is_torch_greater_or_equal("2.1", accept_dev=True) +is_torch_greater_or_equal_than_2_0 = is_torch_greater_or_equal("2.0", accept_dev=True) +is_torch_greater_or_equal_than_1_13 = is_torch_greater_or_equal("1.13", accept_dev=True) +is_torch_greater_or_equal_than_1_12 = is_torch_greater_or_equal("1.12", accept_dev=True) + +# Cache this result has it's a C FFI call which can be pretty time-consuming +_torch_distributed_available = torch.distributed.is_available() + + +def softmax_backward_data(parent, grad_output, output): + """ + A function that calls the internal `_softmax_backward_data` PyTorch method and that adjusts the arguments according + to the torch version detected. + """ + + from torch import _softmax_backward_data + + return _softmax_backward_data(grad_output, output, parent.dim, output.dtype) + + +def prune_linear_layer(layer: nn.Linear, index: torch.LongTensor, dim: int = 0) -> nn.Linear: + """ + Prune a linear layer to keep only entries in index. + + Used to remove heads. + + Args: + layer (`torch.nn.Linear`): The layer to prune. + index (`torch.LongTensor`): The indices to keep in the layer. + dim (`int`, *optional*, defaults to 0): The dimension on which to keep the indices. + + Returns: + `torch.nn.Linear`: The pruned layer as a new layer with `requires_grad=True`. + """ + index = index.to(layer.weight.device) + W = layer.weight.index_select(dim, index).detach().clone() + if layer.bias is not None: + if dim == 1: + b = layer.bias.detach().clone() + else: + b = layer.bias[index].detach().clone() + new_size = list(layer.weight.size()) + new_size[dim] = len(index) + new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device) + new_layer.weight.requires_grad = False + new_layer.weight.copy_(W.contiguous()) + new_layer.weight.requires_grad = True + if layer.bias is not None: + new_layer.bias.requires_grad = False + new_layer.bias.copy_(b.contiguous()) + new_layer.bias.requires_grad = True + return new_layer + + +class Conv1D(nn.Module): + """ + 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2). + + Basically works like a linear layer but the weights are transposed. + + Args: + nf (`int`): The number of output features. + nx (`int`): The number of input features. + """ + + def __init__(self, nf, nx): + super().__init__() + self.nf = nf + self.nx = nx + self.weight = nn.Parameter(torch.empty(nx, nf)) + self.bias = nn.Parameter(torch.zeros(nf)) + nn.init.normal_(self.weight, std=0.02) + + def __repr__(self) -> str: + return "Conv1D(nf={nf}, nx={nx})".format(**self.__dict__) + + def forward(self, x): + size_out = x.size()[:-1] + (self.nf,) + x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) + x = x.view(size_out) + return x + + +def prune_conv1d_layer(layer: Conv1D, index: torch.LongTensor, dim: int = 1) -> Conv1D: + """ + Prune a Conv1D layer to keep only entries in index. A Conv1D work as a Linear layer (see e.g. BERT) but the weights + are transposed. + + Used to remove heads. + + Args: + layer ([`~pytorch_utils.Conv1D`]): The layer to prune. + index (`torch.LongTensor`): The indices to keep in the layer. + dim (`int`, *optional*, defaults to 1): The dimension on which to keep the indices. + + Returns: + [`~pytorch_utils.Conv1D`]: The pruned layer as a new layer with `requires_grad=True`. + """ + index = index.to(layer.weight.device) + W = layer.weight.index_select(dim, index).detach().clone() + if dim == 0: + b = layer.bias.detach().clone() + else: + b = layer.bias[index].detach().clone() + new_size = list(layer.weight.size()) + new_size[dim] = len(index) + new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device) + new_layer.weight.requires_grad = False + new_layer.weight.copy_(W.contiguous()) + new_layer.weight.requires_grad = True + new_layer.bias.requires_grad = False + new_layer.bias.copy_(b.contiguous()) + new_layer.bias.requires_grad = True + return new_layer + + +def prune_layer(layer: nn.Linear | Conv1D, index: torch.LongTensor, dim: int | None = None) -> nn.Linear | Conv1D: + """ + Prune a Conv1D or linear layer to keep only entries in index. + + Used to remove heads. + + Args: + layer (`Union[torch.nn.Linear, Conv1D]`): The layer to prune. + index (`torch.LongTensor`): The indices to keep in the layer. + dim (`int`, *optional*): The dimension on which to keep the indices. + + Returns: + `torch.nn.Linear` or [`~pytorch_utils.Conv1D`]: The pruned layer as a new layer with `requires_grad=True`. + """ + if isinstance(layer, nn.Linear): + return prune_linear_layer(layer, index, dim=0 if dim is None else dim) + elif isinstance(layer, Conv1D): + return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim) + else: + raise ValueError(f"Can't prune layer of class {layer.__class__}") + + +def apply_chunking_to_forward( + forward_fn: Callable[..., torch.Tensor], + chunk_size: int, + chunk_dim: int, + *input_tensors, +) -> torch.Tensor: + """ + This function chunks the `input_tensors` into smaller input tensor parts of size `chunk_size` over the dimension + `chunk_dim`. It then applies a layer `forward_fn` to each chunk independently to save memory. + + If the `forward_fn` is independent across the `chunk_dim` this function will yield the same result as directly + applying `forward_fn` to `input_tensors`. + + Args: + forward_fn (`Callable[..., torch.Tensor]`): + The forward function of the model. + chunk_size (`int`): + The chunk size of a chunked tensor: `num_chunks = len(input_tensors[0]) / chunk_size`. + chunk_dim (`int`): + The dimension over which the `input_tensors` should be chunked. + input_tensors (`tuple[torch.Tensor]`): + The input tensors of `forward_fn` which will be chunked + + Returns: + `torch.Tensor`: A tensor with the same shape as the `forward_fn` would have given if applied`. + + + Examples: + + ```python + # rename the usual forward() fn to forward_chunk() + def forward_chunk(self, hidden_states): + hidden_states = self.decoder(hidden_states) + return hidden_states + + + # implement a chunked forward function + def forward(self, hidden_states): + return apply_chunking_to_forward(self.forward_chunk, self.chunk_size_lm_head, self.seq_len_dim, hidden_states) + ```""" + + assert len(input_tensors) > 0, f"{input_tensors} has to be a tuple/list of tensors" + + # inspect.signature exist since python 3.5 and is a python method -> no problem with backward compatibility + num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters) + if num_args_in_forward_chunk_fn != len(input_tensors): + raise ValueError( + f"forward_chunk_fn expects {num_args_in_forward_chunk_fn} arguments, but only {len(input_tensors)} input " + "tensors are given" + ) + + if chunk_size > 0: + tensor_shape = input_tensors[0].shape[chunk_dim] + for input_tensor in input_tensors: + if input_tensor.shape[chunk_dim] != tensor_shape: + raise ValueError( + f"All input tenors have to be of the same shape: {tensor_shape}, " + f"found shape {input_tensor.shape[chunk_dim]}" + ) + + if input_tensors[0].shape[chunk_dim] % chunk_size != 0: + raise ValueError( + f"The dimension to be chunked {input_tensors[0].shape[chunk_dim]} has to be a multiple of the chunk " + f"size {chunk_size}" + ) + + num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size + + # chunk input tensor into tuples + input_tensors_chunks = tuple(input_tensor.chunk(num_chunks, dim=chunk_dim) for input_tensor in input_tensors) + # apply forward fn to every tuple + output_chunks = tuple(forward_fn(*input_tensors_chunk) for input_tensors_chunk in zip(*input_tensors_chunks)) + # concatenate output at same dimension + return torch.cat(output_chunks, dim=chunk_dim) + + return forward_fn(*input_tensors) + + +def find_pruneable_heads_and_indices( + heads: list[int], n_heads: int, head_size: int, already_pruned_heads: set[int] +) -> tuple[set[int], torch.LongTensor]: + """ + Finds the heads and their indices taking `already_pruned_heads` into account. + + Args: + heads (`list[int]`): List of the indices of heads to prune. + n_heads (`int`): The number of heads in the model. + head_size (`int`): The size of each head. + already_pruned_heads (`Set[int]`): A set of already pruned heads. + + Returns: + `tuple[Set[int], torch.LongTensor]`: A tuple with the indices of heads to prune taking `already_pruned_heads` + into account and the indices of rows/columns to keep in the layer weight. + """ + mask = torch.ones(n_heads, head_size) + heads = set(heads) - already_pruned_heads # Convert to set and remove already pruned heads + for head in heads: + # Compute how many pruned heads are before the head and move the index accordingly + head = head - sum(1 if h < head else 0 for h in already_pruned_heads) + mask[head] = 0 + mask = mask.view(-1).contiguous().eq(1) + index: torch.LongTensor = torch.arange(len(mask))[mask].long() + return heads, index + + +def meshgrid(*tensors: torch.Tensor | list[torch.Tensor], indexing: str | None = None) -> tuple[torch.Tensor, ...]: + """ + Wrapper around torch.meshgrid to avoid warning messages about the introduced `indexing` argument. + + Reference: https://pytorch.org/docs/1.13/generated/torch.meshgrid.html + """ + return torch.meshgrid(*tensors, indexing=indexing) + + +def id_tensor_storage(tensor: torch.Tensor) -> tuple[torch.device, int, int]: + """ + Unique identifier to a tensor storage. Multiple different tensors can share the same underlying storage. For + example, "meta" tensors all share the same storage, and thus their identifier will all be equal. This identifier is + guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with + non-overlapping lifetimes may have the same id. + """ + if _torch_distributed_available and is_torch_greater_or_equal("2.5"): + from torch.distributed.tensor import DTensor + + if isinstance(tensor, DTensor): + local_tensor = tensor.to_local() + return tensor.device, local_tensor.storage().data_ptr(), tensor.nbytes + + if tensor.device.type == "xla" and is_torch_xla_available(): + # NOTE: xla tensors dont have storage + # use some other unique id to distinguish. + # this is a XLA tensor, it must be created using torch_xla's + # device. So the following import is safe: + import torch_xla + + unique_id = torch_xla._XLAC._xla_get_tensor_id(tensor) + else: + unique_id = storage_ptr(tensor) + + return tensor.device, unique_id, storage_size(tensor) + + +def isin_mps_friendly(elements: torch.Tensor, test_elements: torch.Tensor | int) -> torch.Tensor: + """ + Same as `torch.isin` without flags, but MPS-friendly. We can remove this function when we stop supporting + torch <= 2.3. See https://github.com/pytorch/pytorch/issues/77764#issuecomment-2067838075 + + Args: + elements (`torch.Tensor`): Input elements + test_elements (`torch.Tensor` or `int`): The elements to check against. + + Returns: + `torch.Tensor`: A boolean tensor of the same shape as `elements` that is True for `elements` in `test_elements` + and False otherwise + """ + + if elements.device.type == "mps" and not is_torch_greater_or_equal_than_2_4: + test_elements = torch.tensor(test_elements) + if test_elements.ndim == 0: + test_elements = test_elements.unsqueeze(0) + return elements.tile(test_elements.shape[0], 1).eq(test_elements.unsqueeze(1)).sum(dim=0).bool().squeeze() + else: + # Note: don't use named arguments in `torch.isin`, see https://github.com/pytorch/pytorch/issues/126045 + return torch.isin(elements, test_elements) + + +@wraps(lru_cache) +def compile_compatible_method_lru_cache(*lru_args, **lru_kwargs): + """ + LRU cache decorator from standard functools library, but with a workaround to disable + caching when torchdynamo is compiling. Expected to work with class methods. + """ + + def decorator(func): + func_with_cache = lru_cache(*lru_args, **lru_kwargs)(func) + + @wraps(func) + def wrapper(*args, **kwargs): + if is_torchdynamo_compiling(): + return func(*args, **kwargs) + else: + return func_with_cache(*args, **kwargs) + + return wrapper + + return decorator + + +def infer_device(): + """ + Infers available device. + """ + torch_device = "cpu" + if torch.cuda.is_available(): + torch_device = "cuda" + elif is_torch_xpu_available(): + torch_device = "xpu" + + return torch_device diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/safetensors_conversion.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/safetensors_conversion.py new file mode 100644 index 0000000000000000000000000000000000000000..f1612d3ea57c98fd1d383887cfbeb4e2882d3963 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/safetensors_conversion.py @@ -0,0 +1,105 @@ +from typing import Optional + +import requests +from huggingface_hub import Discussion, HfApi, get_repo_discussions + +from .utils import cached_file, http_user_agent, logging + + +logger = logging.get_logger(__name__) + + +def previous_pr(api: HfApi, model_id: str, pr_title: str, token: str) -> Optional["Discussion"]: + main_commit = api.list_repo_commits(model_id, token=token)[0].commit_id + for discussion in get_repo_discussions(repo_id=model_id, token=token): + if discussion.title == pr_title and discussion.status == "open" and discussion.is_pull_request: + commits = api.list_repo_commits(model_id, revision=discussion.git_reference, token=token) + + if main_commit == commits[1].commit_id: + return discussion + return None + + +def spawn_conversion(token: str, private: bool, model_id: str): + logger.info("Attempting to convert .bin model on the fly to safetensors.") + + safetensors_convert_space_url = "https://safetensors-convert.hf.space" + sse_url = f"{safetensors_convert_space_url}/call/run" + + def start(_sse_connection): + for line in _sse_connection.iter_lines(): + line = line.decode() + if line.startswith("event:"): + status = line[7:] + logger.debug(f"Safetensors conversion status: {status}") + + if status == "complete": + return + elif status == "heartbeat": + logger.debug("Heartbeat") + else: + logger.debug(f"Unknown status {status}") + else: + logger.debug(line) + + data = {"data": [model_id, private, token]} + + result = requests.post(sse_url, stream=True, json=data).json() + event_id = result["event_id"] + + with requests.get(f"{sse_url}/{event_id}", stream=True) as sse_connection: + try: + logger.debug("Spawning safetensors automatic conversion.") + start(sse_connection) + except Exception as e: + logger.warning(f"Error during conversion: {repr(e)}") + + +def get_conversion_pr_reference(api: HfApi, model_id: str, **kwargs): + private = api.model_info(model_id).private + + logger.info("Attempting to create safetensors variant") + pr_title = "Adding `safetensors` variant of this model" + token = kwargs.get("token") + + # This looks into the current repo's open PRs to see if a PR for safetensors was already open. If so, it + # returns it. It checks that the PR was opened by the bot and not by another user so as to prevent + # security breaches. + pr = previous_pr(api, model_id, pr_title, token=token) + + if pr is None or (not private and pr.author != "SFconvertbot"): + spawn_conversion(token, private, model_id) + pr = previous_pr(api, model_id, pr_title, token=token) + else: + logger.info("Safetensors PR exists") + + sha = f"refs/pr/{pr.num}" + + return sha + + +def auto_conversion(pretrained_model_name_or_path: str, ignore_errors_during_conversion=False, **cached_file_kwargs): + try: + api = HfApi(token=cached_file_kwargs.get("token"), headers={"user-agent": http_user_agent()}) + sha = get_conversion_pr_reference(api, pretrained_model_name_or_path, **cached_file_kwargs) + + if sha is None: + return None, None + cached_file_kwargs["revision"] = sha + del cached_file_kwargs["_commit_hash"] + + # This is an additional HEAD call that could be removed if we could infer sharded/non-sharded from the PR + # description. + sharded = api.file_exists( + pretrained_model_name_or_path, + "model.safetensors.index.json", + revision=sha, + token=cached_file_kwargs.get("token"), + ) + filename = "model.safetensors.index.json" if sharded else "model.safetensors" + + resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs) + return resolved_archive_file, sha, sharded + except Exception as e: + if not ignore_errors_during_conversion: + raise e diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/testing_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/testing_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..499716789b9bc5197b2669bb4099c13741f4fd79 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/testing_utils.py @@ -0,0 +1,4154 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import ast +import collections +import contextlib +import copy +import doctest +import functools +import gc +import importlib +import inspect +import logging +import multiprocessing +import os +import re +import shlex +import shutil +import subprocess +import sys +import tempfile +import threading +import time +import traceback +import types +import unittest +from collections import UserDict, defaultdict +from collections.abc import Generator, Iterable, Iterator, Mapping +from dataclasses import MISSING, fields +from functools import cache, wraps +from io import StringIO +from pathlib import Path +from typing import Any, Callable, Optional, Union +from unittest import mock +from unittest.mock import patch + +import huggingface_hub.utils +import requests +import urllib3 +from huggingface_hub import delete_repo +from packaging import version + +from transformers import Trainer +from transformers import logging as transformers_logging + +from .integrations import ( + is_clearml_available, + is_optuna_available, + is_ray_available, + is_sigopt_available, + is_swanlab_available, + is_tensorboard_available, + is_trackio_available, + is_wandb_available, +) +from .integrations.deepspeed import is_deepspeed_available +from .utils import ( + ACCELERATE_MIN_VERSION, + GGUF_MIN_VERSION, + TRITON_MIN_VERSION, + is_accelerate_available, + is_apex_available, + is_apollo_torch_available, + is_aqlm_available, + is_auto_awq_available, + is_auto_gptq_available, + is_auto_round_available, + is_av_available, + is_bitsandbytes_available, + is_bitsandbytes_multi_backend_available, + is_bs4_available, + is_compressed_tensors_available, + is_cv2_available, + is_cython_available, + is_decord_available, + is_detectron2_available, + is_eetq_available, + is_essentia_available, + is_faiss_available, + is_fbgemm_gpu_available, + is_flash_attn_2_available, + is_flash_attn_3_available, + is_flax_available, + is_flute_available, + is_fp_quant_available, + is_fsdp_available, + is_ftfy_available, + is_g2p_en_available, + is_galore_torch_available, + is_gguf_available, + is_gptqmodel_available, + is_grokadamw_available, + is_hadamard_available, + is_hqq_available, + is_huggingface_hub_greater_or_equal, + is_ipex_available, + is_jinja_available, + is_jumanpp_available, + is_keras_nlp_available, + is_kernels_available, + is_levenshtein_available, + is_librosa_available, + is_liger_kernel_available, + is_lomo_available, + is_mistral_common_available, + is_natten_available, + is_nltk_available, + is_onnx_available, + is_openai_available, + is_optimum_available, + is_optimum_quanto_available, + is_pandas_available, + is_peft_available, + is_phonemizer_available, + is_pretty_midi_available, + is_psutil_available, + is_pyctcdecode_available, + is_pytesseract_available, + is_pytest_available, + is_pytorch_quantization_available, + is_quark_available, + is_qutlass_available, + is_rjieba_available, + is_sacremoses_available, + is_safetensors_available, + is_schedulefree_available, + is_scipy_available, + is_sentencepiece_available, + is_seqio_available, + is_spacy_available, + is_speech_available, + is_spqr_available, + is_sudachi_available, + is_sudachi_projection_available, + is_tf_available, + is_tiktoken_available, + is_timm_available, + is_tokenizers_available, + is_torch_available, + is_torch_bf16_available_on_device, + is_torch_bf16_gpu_available, + is_torch_fp16_available_on_device, + is_torch_greater_or_equal, + is_torch_hpu_available, + is_torch_mlu_available, + is_torch_neuroncore_available, + is_torch_npu_available, + is_torch_optimi_available, + is_torch_tensorrt_fx_available, + is_torch_tf32_available, + is_torch_xla_available, + is_torch_xpu_available, + is_torchao_available, + is_torchaudio_available, + is_torchcodec_available, + is_torchdynamo_available, + is_torchvision_available, + is_triton_available, + is_vision_available, + is_vptq_available, + strtobool, +) + + +if is_accelerate_available(): + from accelerate.state import AcceleratorState, PartialState + from accelerate.utils.imports import is_fp8_available + + +if is_pytest_available(): + from _pytest.doctest import ( + Module, + _get_checker, + _get_continue_on_failure, + _get_runner, + _is_mocked, + _patch_unwrap_mock_aware, + get_optionflags, + ) + from _pytest.outcomes import skip + from _pytest.pathlib import import_path + from pytest import DoctestItem +else: + Module = object + DoctestItem = object + + +SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy" +DUMMY_UNKNOWN_IDENTIFIER = "julien-c/dummy-unknown" +DUMMY_DIFF_TOKENIZER_IDENTIFIER = "julien-c/dummy-diff-tokenizer" +# Used to test Auto{Config, Model, Tokenizer} model_type detection. + +# Used to test the hub +USER = "__DUMMY_TRANSFORMERS_USER__" +ENDPOINT_STAGING = "https://hub-ci.huggingface.co" + +# Not critical, only usable on the sandboxed CI instance. +TOKEN = "hf_94wBhPGp6KrrTH3KDchhKpRxZwd6dmHWLL" + + +# Used in CausalLMModelTester (and related classes/methods) to infer the common model classes from the base model class +_COMMON_MODEL_NAMES_MAP = { + "config_class": "Config", + "causal_lm_class": "ForCausalLM", + "question_answering_class": "ForQuestionAnswering", + "sequence_classification_class": "ForSequenceClassification", + "token_classification_class": "ForTokenClassification", +} + + +if is_torch_available(): + import torch + + IS_ROCM_SYSTEM = torch.version.hip is not None + IS_CUDA_SYSTEM = torch.version.cuda is not None + IS_XPU_SYSTEM = getattr(torch.version, "xpu", None) is not None +else: + IS_ROCM_SYSTEM = False + IS_CUDA_SYSTEM = False + IS_XPU_SYSTEM = False + +logger = transformers_logging.get_logger(__name__) + + +def parse_flag_from_env(key, default=False): + try: + value = os.environ[key] + except KeyError: + # KEY isn't set, default to `default`. + _value = default + else: + # KEY is set, convert it to True or False. + try: + _value = strtobool(value) + except ValueError: + # More values are supported, but let's keep the message simple. + raise ValueError(f"If set, {key} must be yes or no.") + return _value + + +def parse_int_from_env(key, default=None): + try: + value = os.environ[key] + except KeyError: + _value = default + else: + try: + _value = int(value) + except ValueError: + raise ValueError(f"If set, {key} must be a int.") + return _value + + +_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False) +_run_flaky_tests = parse_flag_from_env("RUN_FLAKY", default=True) +_run_custom_tokenizers = parse_flag_from_env("RUN_CUSTOM_TOKENIZERS", default=False) +_run_staging = parse_flag_from_env("HUGGINGFACE_CO_STAGING", default=False) +_run_pipeline_tests = parse_flag_from_env("RUN_PIPELINE_TESTS", default=True) +_run_agent_tests = parse_flag_from_env("RUN_AGENT_TESTS", default=False) + + +def is_staging_test(test_case): + """ + Decorator marking a test as a staging test. + + Those tests will run using the staging environment of huggingface.co instead of the real model hub. + """ + if not _run_staging: + return unittest.skip(reason="test is staging test")(test_case) + else: + try: + import pytest # We don't need a hard dependency on pytest in the main library + except ImportError: + return test_case + else: + return pytest.mark.is_staging_test()(test_case) + + +def is_pipeline_test(test_case): + """ + Decorator marking a test as a pipeline test. If RUN_PIPELINE_TESTS is set to a falsy value, those tests will be + skipped. + """ + if not _run_pipeline_tests: + return unittest.skip(reason="test is pipeline test")(test_case) + else: + try: + import pytest # We don't need a hard dependency on pytest in the main library + except ImportError: + return test_case + else: + return pytest.mark.is_pipeline_test()(test_case) + + +def is_agent_test(test_case): + """ + Decorator marking a test as an agent test. If RUN_TOOL_TESTS is set to a falsy value, those tests will be skipped. + """ + if not _run_agent_tests: + return unittest.skip(reason="test is an agent test")(test_case) + else: + try: + import pytest # We don't need a hard dependency on pytest in the main library + except ImportError: + return test_case + else: + return pytest.mark.is_agent_test()(test_case) + + +def slow(test_case): + """ + Decorator marking a test as slow. + + Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them. + + """ + return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case) + + +def tooslow(test_case): + """ + Decorator marking a test as too slow. + + Slow tests are skipped while they're in the process of being fixed. No test should stay tagged as "tooslow" as + these will not be tested by the CI. + + """ + return unittest.skip(reason="test is too slow")(test_case) + + +def skip_if_not_implemented(test_func): + @functools.wraps(test_func) + def wrapper(*args, **kwargs): + try: + return test_func(*args, **kwargs) + except NotImplementedError as e: + raise unittest.SkipTest(f"Test skipped due to NotImplementedError: {e}") + + return wrapper + + +def apply_skip_if_not_implemented(cls): + """ + Class decorator to apply @skip_if_not_implemented to all test methods. + """ + for attr_name in dir(cls): + if attr_name.startswith("test_"): + attr = getattr(cls, attr_name) + if callable(attr): + setattr(cls, attr_name, skip_if_not_implemented(attr)) + return cls + + +def custom_tokenizers(test_case): + """ + Decorator marking a test for a custom tokenizer. + + Custom tokenizers require additional dependencies, and are skipped by default. Set the RUN_CUSTOM_TOKENIZERS + environment variable to a truthy value to run them. + """ + return unittest.skipUnless(_run_custom_tokenizers, "test of custom tokenizers")(test_case) + + +def require_bs4(test_case): + """ + Decorator marking a test that requires BeautifulSoup4. These tests are skipped when BeautifulSoup4 isn't installed. + """ + return unittest.skipUnless(is_bs4_available(), "test requires BeautifulSoup4")(test_case) + + +def require_galore_torch(test_case): + """ + Decorator marking a test that requires GaLore. These tests are skipped when GaLore isn't installed. + https://github.com/jiaweizzhao/GaLore + """ + return unittest.skipUnless(is_galore_torch_available(), "test requires GaLore")(test_case) + + +def require_apollo_torch(test_case): + """ + Decorator marking a test that requires GaLore. These tests are skipped when APOLLO isn't installed. + https://github.com/zhuhanqing/APOLLO + """ + return unittest.skipUnless(is_apollo_torch_available(), "test requires APOLLO")(test_case) + + +def require_torch_optimi(test_case): + """ + Decorator marking a test that requires torch-optimi. These tests are skipped when torch-optimi isn't installed. + https://github.com/jxnl/torch-optimi + """ + return unittest.skipUnless(is_torch_optimi_available(), "test requires torch-optimi")(test_case) + + +def require_lomo(test_case): + """ + Decorator marking a test that requires LOMO. These tests are skipped when LOMO-optim isn't installed. + https://github.com/OpenLMLab/LOMO + """ + return unittest.skipUnless(is_lomo_available(), "test requires LOMO")(test_case) + + +def require_grokadamw(test_case): + """ + Decorator marking a test that requires GrokAdamW. These tests are skipped when GrokAdamW isn't installed. + """ + return unittest.skipUnless(is_grokadamw_available(), "test requires GrokAdamW")(test_case) + + +def require_schedulefree(test_case): + """ + Decorator marking a test that requires schedulefree. These tests are skipped when schedulefree isn't installed. + https://github.com/facebookresearch/schedule_free + """ + return unittest.skipUnless(is_schedulefree_available(), "test requires schedulefree")(test_case) + + +def require_cv2(test_case): + """ + Decorator marking a test that requires OpenCV. + + These tests are skipped when OpenCV isn't installed. + + """ + return unittest.skipUnless(is_cv2_available(), "test requires OpenCV")(test_case) + + +def require_levenshtein(test_case): + """ + Decorator marking a test that requires Levenshtein. + + These tests are skipped when Levenshtein isn't installed. + + """ + return unittest.skipUnless(is_levenshtein_available(), "test requires Levenshtein")(test_case) + + +def require_nltk(test_case): + """ + Decorator marking a test that requires NLTK. + + These tests are skipped when NLTK isn't installed. + + """ + return unittest.skipUnless(is_nltk_available(), "test requires NLTK")(test_case) + + +def require_accelerate(test_case, min_version: str = ACCELERATE_MIN_VERSION): + """ + Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed. + """ + return unittest.skipUnless( + is_accelerate_available(min_version), f"test requires accelerate version >= {min_version}" + )(test_case) + + +def require_triton(min_version: str = TRITON_MIN_VERSION): + """ + Decorator marking a test that requires triton. These tests are skipped when triton isn't installed. + """ + + def decorator(test_case): + return unittest.skipUnless(is_triton_available(min_version), f"test requires triton version >= {min_version}")( + test_case + ) + + return decorator + + +def require_gguf(test_case, min_version: str = GGUF_MIN_VERSION): + """ + Decorator marking a test that requires ggguf. These tests are skipped when gguf isn't installed. + """ + return unittest.skipUnless(is_gguf_available(min_version), f"test requires gguf version >= {min_version}")( + test_case + ) + + +def require_fsdp(test_case, min_version: str = "1.12.0"): + """ + Decorator marking a test that requires fsdp. These tests are skipped when fsdp isn't installed. + """ + return unittest.skipUnless(is_fsdp_available(min_version), f"test requires torch version >= {min_version}")( + test_case + ) + + +def require_g2p_en(test_case): + """ + Decorator marking a test that requires g2p_en. These tests are skipped when SentencePiece isn't installed. + """ + return unittest.skipUnless(is_g2p_en_available(), "test requires g2p_en")(test_case) + + +def require_safetensors(test_case): + """ + Decorator marking a test that requires safetensors. These tests are skipped when safetensors isn't installed. + """ + return unittest.skipUnless(is_safetensors_available(), "test requires safetensors")(test_case) + + +def require_rjieba(test_case): + """ + Decorator marking a test that requires rjieba. These tests are skipped when rjieba isn't installed. + """ + return unittest.skipUnless(is_rjieba_available(), "test requires rjieba")(test_case) + + +def require_jinja(test_case): + """ + Decorator marking a test that requires jinja. These tests are skipped when jinja isn't installed. + """ + return unittest.skipUnless(is_jinja_available(), "test requires jinja")(test_case) + + +def require_onnx(test_case): + return unittest.skipUnless(is_onnx_available(), "test requires ONNX")(test_case) + + +def require_timm(test_case): + """ + Decorator marking a test that requires Timm. + + These tests are skipped when Timm isn't installed. + + """ + return unittest.skipUnless(is_timm_available(), "test requires Timm")(test_case) + + +def require_natten(test_case): + """ + Decorator marking a test that requires NATTEN. + + These tests are skipped when NATTEN isn't installed. + + """ + return unittest.skipUnless(is_natten_available(), "test requires natten")(test_case) + + +def require_torch(test_case): + """ + Decorator marking a test that requires PyTorch. + + These tests are skipped when PyTorch isn't installed. + + """ + return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case) + + +def require_torch_greater_or_equal(version: str): + """ + Decorator marking a test that requires PyTorch version >= `version`. + + These tests are skipped when PyTorch version is less than `version`. + """ + + def decorator(test_case): + return unittest.skipUnless(is_torch_greater_or_equal(version), f"test requires PyTorch version >= {version}")( + test_case + ) + + return decorator + + +def require_huggingface_hub_greater_or_equal(version: str): + """ + Decorator marking a test that requires huggingface_hub version >= `version`. + + These tests are skipped when huggingface_hub version is less than `version`. + """ + + def decorator(test_case): + return unittest.skipUnless( + is_huggingface_hub_greater_or_equal(version), f"test requires huggingface_hub version >= {version}" + )(test_case) + + return decorator + + +def require_flash_attn(test_case): + """ + Decorator marking a test that requires Flash Attention. + + These tests are skipped when Flash Attention isn't installed. + + """ + flash_attn_available = is_flash_attn_2_available() + kernels_available = is_kernels_available() + try: + from kernels import get_kernel + + get_kernel("kernels-community/flash-attn") + except Exception as _: + kernels_available = False + + return unittest.skipUnless(kernels_available | flash_attn_available, "test requires Flash Attention")(test_case) + + +def require_kernels(test_case): + """ + Decorator marking a test that requires the kernels library. + + These tests are skipped when the kernels library isn't installed. + + """ + return unittest.skipUnless(is_kernels_available(), "test requires the kernels library")(test_case) + + +def require_flash_attn_3(test_case): + """ + Decorator marking a test that requires Flash Attention 3. + + These tests are skipped when Flash Attention 3 isn't installed. + """ + return unittest.skipUnless(is_flash_attn_3_available(), "test requires Flash Attention 3")(test_case) + + +def require_read_token(test_case): + """ + A decorator that loads the HF token for tests that require to load gated models. + """ + token = os.getenv("HF_HUB_READ_TOKEN") + + if isinstance(test_case, type): + for attr_name in dir(test_case): + attr = getattr(test_case, attr_name) + if isinstance(attr, types.FunctionType): + if getattr(attr, "__require_read_token__", False): + continue + wrapped = require_read_token(attr) + setattr(test_case, attr_name, wrapped) + return test_case + else: + if getattr(test_case, "__require_read_token__", False): + return test_case + + @functools.wraps(test_case) + def wrapper(*args, **kwargs): + if token is not None: + with patch("huggingface_hub.utils._headers.get_token", return_value=token): + return test_case(*args, **kwargs) + else: # Allow running locally with the default token env variable + # dealing with static/class methods and called by `self.xxx` + if "staticmethod" in inspect.getsource(test_case).strip(): + if len(args) > 0 and isinstance(args[0], unittest.TestCase): + return test_case(*args[1:], **kwargs) + return test_case(*args, **kwargs) + + wrapper.__require_read_token__ = True + return wrapper + + +def require_peft(test_case): + """ + Decorator marking a test that requires PEFT. + + These tests are skipped when PEFT isn't installed. + + """ + return unittest.skipUnless(is_peft_available(), "test requires PEFT")(test_case) + + +def require_torchvision(test_case): + """ + Decorator marking a test that requires Torchvision. + + These tests are skipped when Torchvision isn't installed. + + """ + return unittest.skipUnless(is_torchvision_available(), "test requires Torchvision")(test_case) + + +def require_torchcodec(test_case): + """ + Decorator marking a test that requires Torchcodec. + + These tests are skipped when Torchcodec isn't installed. + + """ + return unittest.skipUnless(is_torchcodec_available(), "test requires Torchcodec")(test_case) + + +def require_torch_or_tf(test_case): + """ + Decorator marking a test that requires PyTorch or TensorFlow. + + These tests are skipped when neither PyTorch not TensorFlow is installed. + + """ + return unittest.skipUnless(is_torch_available() or is_tf_available(), "test requires PyTorch or TensorFlow")( + test_case + ) + + +def require_intel_extension_for_pytorch(test_case): + """ + Decorator marking a test that requires Intel Extension for PyTorch. + + These tests are skipped when Intel Extension for PyTorch isn't installed or it does not match current PyTorch + version. + + """ + return unittest.skipUnless( + is_ipex_available(), + "test requires Intel Extension for PyTorch to be installed and match current PyTorch version, see" + " https://github.com/intel/intel-extension-for-pytorch", + )(test_case) + + +def require_torchaudio(test_case): + """ + Decorator marking a test that requires torchaudio. These tests are skipped when torchaudio isn't installed. + """ + return unittest.skipUnless(is_torchaudio_available(), "test requires torchaudio")(test_case) + + +def require_sentencepiece(test_case): + """ + Decorator marking a test that requires SentencePiece. These tests are skipped when SentencePiece isn't installed. + """ + return unittest.skipUnless(is_sentencepiece_available(), "test requires SentencePiece")(test_case) + + +def require_sacremoses(test_case): + """ + Decorator marking a test that requires Sacremoses. These tests are skipped when Sacremoses isn't installed. + """ + return unittest.skipUnless(is_sacremoses_available(), "test requires Sacremoses")(test_case) + + +def require_seqio(test_case): + """ + Decorator marking a test that requires SentencePiece. These tests are skipped when SentencePiece isn't installed. + """ + return unittest.skipUnless(is_seqio_available(), "test requires Seqio")(test_case) + + +def require_scipy(test_case): + """ + Decorator marking a test that requires Scipy. These tests are skipped when SentencePiece isn't installed. + """ + return unittest.skipUnless(is_scipy_available(), "test requires Scipy")(test_case) + + +def require_tokenizers(test_case): + """ + Decorator marking a test that requires 🤗 Tokenizers. These tests are skipped when 🤗 Tokenizers isn't installed. + """ + return unittest.skipUnless(is_tokenizers_available(), "test requires tokenizers")(test_case) + + +def require_keras_nlp(test_case): + """ + Decorator marking a test that requires keras_nlp. These tests are skipped when keras_nlp isn't installed. + """ + return unittest.skipUnless(is_keras_nlp_available(), "test requires keras_nlp")(test_case) + + +def require_pandas(test_case): + """ + Decorator marking a test that requires pandas. These tests are skipped when pandas isn't installed. + """ + return unittest.skipUnless(is_pandas_available(), "test requires pandas")(test_case) + + +def require_pytesseract(test_case): + """ + Decorator marking a test that requires PyTesseract. These tests are skipped when PyTesseract isn't installed. + """ + return unittest.skipUnless(is_pytesseract_available(), "test requires PyTesseract")(test_case) + + +def require_pytorch_quantization(test_case): + """ + Decorator marking a test that requires PyTorch Quantization Toolkit. These tests are skipped when PyTorch + Quantization Toolkit isn't installed. + """ + return unittest.skipUnless(is_pytorch_quantization_available(), "test requires PyTorch Quantization Toolkit")( + test_case + ) + + +def require_vision(test_case): + """ + Decorator marking a test that requires the vision dependencies. These tests are skipped when torchaudio isn't + installed. + """ + return unittest.skipUnless(is_vision_available(), "test requires vision")(test_case) + + +def require_ftfy(test_case): + """ + Decorator marking a test that requires ftfy. These tests are skipped when ftfy isn't installed. + """ + return unittest.skipUnless(is_ftfy_available(), "test requires ftfy")(test_case) + + +def require_spacy(test_case): + """ + Decorator marking a test that requires SpaCy. These tests are skipped when SpaCy isn't installed. + """ + return unittest.skipUnless(is_spacy_available(), "test requires spacy")(test_case) + + +def require_torch_multi_gpu(test_case): + """ + Decorator marking a test that requires a multi-GPU CUDA setup (in PyTorch). These tests are skipped on a machine without + multiple CUDA GPUs. + + To run *only* the multi_gpu tests, assuming all test names contain multi_gpu: $ pytest -sv ./tests -k "multi_gpu" + """ + if not is_torch_available(): + return unittest.skip(reason="test requires PyTorch")(test_case) + + import torch + + return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple CUDA GPUs")(test_case) + + +def require_torch_multi_accelerator(test_case): + """ + Decorator marking a test that requires a multi-accelerator (in PyTorch). These tests are skipped on a machine + without multiple accelerators. To run *only* the multi_accelerator tests, assuming all test names contain + multi_accelerator: $ pytest -sv ./tests -k "multi_accelerator" + """ + if not is_torch_available(): + return unittest.skip(reason="test requires PyTorch")(test_case) + + return unittest.skipUnless(backend_device_count(torch_device) > 1, "test requires multiple accelerators")( + test_case + ) + + +def require_torch_non_multi_gpu(test_case): + """ + Decorator marking a test that requires 0 or 1 GPU setup (in PyTorch). + """ + if not is_torch_available(): + return unittest.skip(reason="test requires PyTorch")(test_case) + + import torch + + return unittest.skipUnless(torch.cuda.device_count() < 2, "test requires 0 or 1 GPU")(test_case) + + +def require_torch_non_multi_accelerator(test_case): + """ + Decorator marking a test that requires 0 or 1 accelerator setup (in PyTorch). + """ + if not is_torch_available(): + return unittest.skip(reason="test requires PyTorch")(test_case) + + return unittest.skipUnless(backend_device_count(torch_device) < 2, "test requires 0 or 1 accelerator")(test_case) + + +def require_torch_up_to_2_gpus(test_case): + """ + Decorator marking a test that requires 0 or 1 or 2 GPU setup (in PyTorch). + """ + if not is_torch_available(): + return unittest.skip(reason="test requires PyTorch")(test_case) + + import torch + + return unittest.skipUnless(torch.cuda.device_count() < 3, "test requires 0 or 1 or 2 GPUs")(test_case) + + +def require_torch_up_to_2_accelerators(test_case): + """ + Decorator marking a test that requires 0 or 1 or 2 accelerator setup (in PyTorch). + """ + if not is_torch_available(): + return unittest.skip(reason="test requires PyTorch")(test_case) + + return unittest.skipUnless(backend_device_count(torch_device) < 3, "test requires 0 or 1 or 2 accelerators")( + test_case + ) + + +def require_torch_xla(test_case): + """ + Decorator marking a test that requires TorchXLA (in PyTorch). + """ + return unittest.skipUnless(is_torch_xla_available(), "test requires TorchXLA")(test_case) + + +def require_torch_neuroncore(test_case): + """ + Decorator marking a test that requires NeuronCore (in PyTorch). + """ + return unittest.skipUnless(is_torch_neuroncore_available(check_device=False), "test requires PyTorch NeuronCore")( + test_case + ) + + +def require_torch_npu(test_case): + """ + Decorator marking a test that requires NPU (in PyTorch). + """ + return unittest.skipUnless(is_torch_npu_available(), "test requires PyTorch NPU")(test_case) + + +def require_torch_multi_npu(test_case): + """ + Decorator marking a test that requires a multi-NPU setup (in PyTorch). These tests are skipped on a machine without + multiple NPUs. + + To run *only* the multi_npu tests, assuming all test names contain multi_npu: $ pytest -sv ./tests -k "multi_npu" + """ + if not is_torch_npu_available(): + return unittest.skip(reason="test requires PyTorch NPU")(test_case) + + return unittest.skipUnless(torch.npu.device_count() > 1, "test requires multiple NPUs")(test_case) + + +def require_non_hpu(test_case): + """ + Decorator marking a test that should be skipped for HPU. + """ + return unittest.skipUnless(torch_device != "hpu", "test requires a non-HPU")(test_case) + + +def require_torch_xpu(test_case): + """ + Decorator marking a test that requires XPU (in PyTorch). + + These tests are skipped when XPU backend is not available. XPU backend might be available either via stock + PyTorch (>=2.4) or via Intel Extension for PyTorch. In the latter case, if IPEX is installed, its version + must match match current PyTorch version. + """ + return unittest.skipUnless(is_torch_xpu_available(), "test requires XPU device")(test_case) + + +def require_non_xpu(test_case): + """ + Decorator marking a test that should be skipped for XPU. + """ + return unittest.skipUnless(torch_device != "xpu", "test requires a non-XPU")(test_case) + + +def require_torch_multi_xpu(test_case): + """ + Decorator marking a test that requires a multi-XPU setup (in PyTorch). These tests are skipped on a machine without + multiple XPUs. + + To run *only* the multi_xpu tests, assuming all test names contain multi_xpu: $ pytest -sv ./tests -k "multi_xpu" + """ + if not is_torch_xpu_available(): + return unittest.skip(reason="test requires PyTorch XPU")(test_case) + + return unittest.skipUnless(torch.xpu.device_count() > 1, "test requires multiple XPUs")(test_case) + + +def require_torch_multi_hpu(test_case): + """ + Decorator marking a test that requires a multi-HPU setup (in PyTorch). These tests are skipped on a machine without + multiple HPUs. + + To run *only* the multi_hpu tests, assuming all test names contain multi_hpu: $ pytest -sv ./tests -k "multi_hpu" + """ + if not is_torch_hpu_available(): + return unittest.skip(reason="test requires PyTorch HPU")(test_case) + + return unittest.skipUnless(torch.hpu.device_count() > 1, "test requires multiple HPUs")(test_case) + + +if is_torch_available(): + # Set env var CUDA_VISIBLE_DEVICES="" to force cpu-mode + import torch + + if "TRANSFORMERS_TEST_BACKEND" in os.environ: + backend = os.environ["TRANSFORMERS_TEST_BACKEND"] + try: + _ = importlib.import_module(backend) + except ModuleNotFoundError as e: + raise ModuleNotFoundError( + f"Failed to import `TRANSFORMERS_TEST_BACKEND` '{backend}'! This should be the name of an installed module. The original error (look up to see its" + f" traceback):\n{e}" + ) from e + + if "TRANSFORMERS_TEST_DEVICE" in os.environ: + torch_device = os.environ["TRANSFORMERS_TEST_DEVICE"] + if torch_device == "cuda" and not torch.cuda.is_available(): + raise ValueError( + f"TRANSFORMERS_TEST_DEVICE={torch_device}, but CUDA is unavailable. Please double-check your testing environment." + ) + if torch_device == "xpu" and not is_torch_xpu_available(): + raise ValueError( + f"TRANSFORMERS_TEST_DEVICE={torch_device}, but XPU is unavailable. Please double-check your testing environment." + ) + if torch_device == "npu" and not is_torch_npu_available(): + raise ValueError( + f"TRANSFORMERS_TEST_DEVICE={torch_device}, but NPU is unavailable. Please double-check your testing environment." + ) + if torch_device == "mlu" and not is_torch_mlu_available(): + raise ValueError( + f"TRANSFORMERS_TEST_DEVICE={torch_device}, but MLU is unavailable. Please double-check your testing environment." + ) + if torch_device == "hpu" and not is_torch_hpu_available(): + raise ValueError( + f"TRANSFORMERS_TEST_DEVICE={torch_device}, but HPU is unavailable. Please double-check your testing environment." + ) + + try: + # try creating device to see if provided device is valid + _ = torch.device(torch_device) + except RuntimeError as e: + raise RuntimeError( + f"Unknown testing device specified by environment variable `TRANSFORMERS_TEST_DEVICE`: {torch_device}" + ) from e + elif torch.cuda.is_available(): + torch_device = "cuda" + elif is_torch_npu_available(): + torch_device = "npu" + elif is_torch_mlu_available(): + torch_device = "mlu" + elif is_torch_hpu_available(): + torch_device = "hpu" + elif is_torch_xpu_available(): + torch_device = "xpu" + else: + torch_device = "cpu" +else: + torch_device = None + +if is_tf_available(): + import tensorflow as tf + +if is_flax_available(): + import jax + + jax_device = jax.default_backend() +else: + jax_device = None + + +def require_torchdynamo(test_case): + """Decorator marking a test that requires TorchDynamo""" + return unittest.skipUnless(is_torchdynamo_available(), "test requires TorchDynamo")(test_case) + + +def require_torchao(test_case): + """Decorator marking a test that requires torchao""" + return unittest.skipUnless(is_torchao_available(), "test requires torchao")(test_case) + + +def require_torchao_version_greater_or_equal(torchao_version): + def decorator(test_case): + correct_torchao_version = is_torchao_available() and version.parse( + version.parse(importlib.metadata.version("torchao")).base_version + ) >= version.parse(torchao_version) + return unittest.skipUnless( + correct_torchao_version, f"Test requires torchao with the version greater than {torchao_version}." + )(test_case) + + return decorator + + +def require_torch_tensorrt_fx(test_case): + """Decorator marking a test that requires Torch-TensorRT FX""" + return unittest.skipUnless(is_torch_tensorrt_fx_available(), "test requires Torch-TensorRT FX")(test_case) + + +def require_torch_gpu(test_case): + """Decorator marking a test that requires CUDA and PyTorch.""" + return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case) + + +def require_torch_mps(test_case): + """Decorator marking a test that requires CUDA and PyTorch.""" + return unittest.skipUnless(torch_device == "mps", "test requires MPS")(test_case) + + +def require_large_cpu_ram(test_case, memory: float = 80): + """Decorator marking a test that requires a CPU RAM with more than `memory` GiB of memory.""" + if not is_psutil_available(): + return test_case + + import psutil + + return unittest.skipUnless( + psutil.virtual_memory().total / 1024**3 > memory, + f"test requires a machine with more than {memory} GiB of CPU RAM memory", + )(test_case) + + +def require_torch_large_gpu(test_case, memory: float = 20): + """Decorator marking a test that requires a CUDA GPU with more than `memory` GiB of memory.""" + if torch_device != "cuda": + return unittest.skip(reason=f"test requires a CUDA GPU with more than {memory} GiB of memory")(test_case) + + return unittest.skipUnless( + torch.cuda.get_device_properties(0).total_memory / 1024**3 > memory, + f"test requires a GPU with more than {memory} GiB of memory", + )(test_case) + + +def require_torch_large_accelerator(test_case, memory: float = 20): + """Decorator marking a test that requires an accelerator with more than `memory` GiB of memory.""" + if torch_device != "cuda" and torch_device != "xpu": + return unittest.skip(reason=f"test requires a GPU or XPU with more than {memory} GiB of memory")(test_case) + + torch_accelerator_module = getattr(torch, torch_device) + + return unittest.skipUnless( + torch_accelerator_module.get_device_properties(0).total_memory / 1024**3 > memory, + f"test requires a GPU or XPU with more than {memory} GiB of memory", + )(test_case) + + +def require_torch_gpu_if_bnb_not_multi_backend_enabled(test_case): + """ + Decorator marking a test that requires a GPU if bitsandbytes multi-backend feature is not enabled. + """ + if is_bitsandbytes_available() and is_bitsandbytes_multi_backend_available(): + return test_case + return require_torch_gpu(test_case) + + +def require_torch_accelerator(test_case): + """Decorator marking a test that requires an accessible accelerator and PyTorch.""" + return unittest.skipUnless(torch_device is not None and torch_device != "cpu", "test requires accelerator")( + test_case + ) + + +def require_torch_fp16(test_case): + """Decorator marking a test that requires a device that supports fp16""" + return unittest.skipUnless( + is_torch_fp16_available_on_device(torch_device), "test requires device with fp16 support" + )(test_case) + + +def require_fp8(test_case): + """Decorator marking a test that requires supports for fp8""" + return unittest.skipUnless(is_accelerate_available() and is_fp8_available(), "test requires fp8 support")( + test_case + ) + + +def require_torch_bf16(test_case): + """Decorator marking a test that requires a device that supports bf16""" + return unittest.skipUnless( + is_torch_bf16_available_on_device(torch_device), "test requires device with bf16 support" + )(test_case) + + +def require_torch_bf16_gpu(test_case): + """Decorator marking a test that requires torch>=1.10, using Ampere GPU or newer arch with cuda>=11.0""" + return unittest.skipUnless( + is_torch_bf16_gpu_available(), + "test requires torch>=1.10, using Ampere GPU or newer arch with cuda>=11.0", + )(test_case) + + +def require_deterministic_for_xpu(test_case): + @wraps(test_case) + def wrapper(*args, **kwargs): + if is_torch_xpu_available(): + original_state = torch.are_deterministic_algorithms_enabled() + try: + torch.use_deterministic_algorithms(True) + return test_case(*args, **kwargs) + finally: + torch.use_deterministic_algorithms(original_state) + else: + return test_case(*args, **kwargs) + + return wrapper + + +def require_torch_tf32(test_case): + """Decorator marking a test that requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7.""" + return unittest.skipUnless( + is_torch_tf32_available(), "test requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7" + )(test_case) + + +def require_detectron2(test_case): + """Decorator marking a test that requires detectron2.""" + return unittest.skipUnless(is_detectron2_available(), "test requires `detectron2`")(test_case) + + +def require_faiss(test_case): + """Decorator marking a test that requires faiss.""" + return unittest.skipUnless(is_faiss_available(), "test requires `faiss`")(test_case) + + +def require_optuna(test_case): + """ + Decorator marking a test that requires optuna. + + These tests are skipped when optuna isn't installed. + + """ + return unittest.skipUnless(is_optuna_available(), "test requires optuna")(test_case) + + +def require_ray(test_case): + """ + Decorator marking a test that requires Ray/tune. + + These tests are skipped when Ray/tune isn't installed. + + """ + return unittest.skipUnless(is_ray_available(), "test requires Ray/tune")(test_case) + + +def require_sigopt(test_case): + """ + Decorator marking a test that requires SigOpt. + + These tests are skipped when SigOpt isn't installed. + + """ + return unittest.skipUnless(is_sigopt_available(), "test requires SigOpt")(test_case) + + +def require_swanlab(test_case): + """ + Decorator marking a test that requires swanlab. + + These tests are skipped when swanlab isn't installed. + + """ + return unittest.skipUnless(is_swanlab_available(), "test requires swanlab")(test_case) + + +def require_trackio(test_case): + """ + Decorator marking a test that requires trackio. + + These tests are skipped when trackio isn't installed. + + """ + return unittest.skipUnless(is_trackio_available(), "test requires trackio")(test_case) + + +def require_wandb(test_case): + """ + Decorator marking a test that requires wandb. + + These tests are skipped when wandb isn't installed. + + """ + return unittest.skipUnless(is_wandb_available(), "test requires wandb")(test_case) + + +def require_clearml(test_case): + """ + Decorator marking a test requires clearml. + + These tests are skipped when clearml isn't installed. + + """ + return unittest.skipUnless(is_clearml_available(), "test requires clearml")(test_case) + + +def require_deepspeed(test_case): + """ + Decorator marking a test that requires deepspeed + """ + return unittest.skipUnless(is_deepspeed_available(), "test requires deepspeed")(test_case) + + +def require_apex(test_case): + """ + Decorator marking a test that requires apex + """ + return unittest.skipUnless(is_apex_available(), "test requires apex")(test_case) + + +def require_aqlm(test_case): + """ + Decorator marking a test that requires aqlm + """ + return unittest.skipUnless(is_aqlm_available(), "test requires aqlm")(test_case) + + +def require_vptq(test_case): + """ + Decorator marking a test that requires vptq + """ + return unittest.skipUnless(is_vptq_available(), "test requires vptq")(test_case) + + +def require_spqr(test_case): + """ + Decorator marking a test that requires spqr + """ + return unittest.skipUnless(is_spqr_available(), "test requires spqr")(test_case) + + +def require_eetq(test_case): + """ + Decorator marking a test that requires eetq + """ + eetq_available = is_eetq_available() + if eetq_available: + try: + import eetq # noqa: F401 + except ImportError as exc: + if "shard_checkpoint" in str(exc): + # EETQ 1.0.0 is currently broken with the latest transformers because it tries to import the removed + # shard_checkpoint function, see https://github.com/NetEase-FuXi/EETQ/issues/34. + # TODO: Remove once eetq releases a fix and this release is used in CI + eetq_available = False + return unittest.skipUnless(eetq_available, "test requires eetq")(test_case) + + +def require_av(test_case): + """ + Decorator marking a test that requires av + """ + return unittest.skipUnless(is_av_available(), "test requires av")(test_case) + + +def require_decord(test_case): + """ + Decorator marking a test that requires decord + """ + return unittest.skipUnless(is_decord_available(), "test requires decord")(test_case) + + +def require_bitsandbytes(test_case): + """ + Decorator marking a test that requires the bitsandbytes library. Will be skipped when the library or its hard dependency torch is not installed. + """ + if is_bitsandbytes_available() and is_torch_available(): + try: + import pytest + + return pytest.mark.bitsandbytes(test_case) + except ImportError: + return test_case + else: + return unittest.skip(reason="test requires bitsandbytes and torch")(test_case) + + +def require_optimum(test_case): + """ + Decorator for optimum dependency + """ + return unittest.skipUnless(is_optimum_available(), "test requires optimum")(test_case) + + +def require_tensorboard(test_case): + """ + Decorator for `tensorboard` dependency + """ + return unittest.skipUnless(is_tensorboard_available(), "test requires tensorboard") + + +def require_gptq(test_case): + """ + Decorator for auto_gptq dependency + """ + return unittest.skipUnless( + is_gptqmodel_available() or is_auto_gptq_available(), "test requires gptqmodel or auto-gptq" + )(test_case) + + +def require_hqq(test_case): + """ + Decorator for hqq dependency + """ + return unittest.skipUnless(is_hqq_available(), "test requires hqq")(test_case) + + +def require_auto_awq(test_case): + """ + Decorator for auto_awq dependency + """ + return unittest.skipUnless(is_auto_awq_available(), "test requires autoawq")(test_case) + + +def require_auto_round(test_case): + """ + Decorator for auto_round dependency + """ + return unittest.skipUnless(is_auto_round_available(), "test requires autoround")(test_case) + + +def require_optimum_quanto(test_case): + """ + Decorator for quanto dependency + """ + return unittest.skipUnless(is_optimum_quanto_available(), "test requires optimum-quanto")(test_case) + + +def require_compressed_tensors(test_case): + """ + Decorator for compressed_tensors dependency + """ + return unittest.skipUnless(is_compressed_tensors_available(), "test requires compressed_tensors")(test_case) + + +def require_fbgemm_gpu(test_case): + """ + Decorator for fbgemm_gpu dependency + """ + return unittest.skipUnless(is_fbgemm_gpu_available(), "test requires fbgemm-gpu")(test_case) + + +def require_quark(test_case): + """ + Decorator for quark dependency + """ + return unittest.skipUnless(is_quark_available(), "test requires quark")(test_case) + + +def require_flute_hadamard(test_case): + """ + Decorator marking a test that requires higgs and hadamard + """ + return unittest.skipUnless( + is_flute_available() and is_hadamard_available(), "test requires flute and fast_hadamard_transform" + )(test_case) + + +def require_fp_quant(test_case): + """ + Decorator marking a test that requires fp_quant and qutlass + """ + return unittest.skipUnless(is_fp_quant_available(), "test requires fp_quant")(test_case) + + +def require_qutlass(test_case): + """ + Decorator marking a test that requires qutlass + """ + return unittest.skipUnless(is_qutlass_available(), "test requires qutlass")(test_case) + + +def require_phonemizer(test_case): + """ + Decorator marking a test that requires phonemizer + """ + return unittest.skipUnless(is_phonemizer_available(), "test requires phonemizer")(test_case) + + +def require_pyctcdecode(test_case): + """ + Decorator marking a test that requires pyctcdecode + """ + return unittest.skipUnless(is_pyctcdecode_available(), "test requires pyctcdecode")(test_case) + + +def require_librosa(test_case): + """ + Decorator marking a test that requires librosa + """ + return unittest.skipUnless(is_librosa_available(), "test requires librosa")(test_case) + + +def require_liger_kernel(test_case): + """ + Decorator marking a test that requires liger_kernel + """ + return unittest.skipUnless(is_liger_kernel_available(), "test requires liger_kernel")(test_case) + + +def require_essentia(test_case): + """ + Decorator marking a test that requires essentia + """ + return unittest.skipUnless(is_essentia_available(), "test requires essentia")(test_case) + + +def require_pretty_midi(test_case): + """ + Decorator marking a test that requires pretty_midi + """ + return unittest.skipUnless(is_pretty_midi_available(), "test requires pretty_midi")(test_case) + + +def cmd_exists(cmd): + return shutil.which(cmd) is not None + + +def require_usr_bin_time(test_case): + """ + Decorator marking a test that requires `/usr/bin/time` + """ + return unittest.skipUnless(cmd_exists("/usr/bin/time"), "test requires /usr/bin/time")(test_case) + + +def require_sudachi(test_case): + """ + Decorator marking a test that requires sudachi + """ + return unittest.skipUnless(is_sudachi_available(), "test requires sudachi")(test_case) + + +def require_sudachi_projection(test_case): + """ + Decorator marking a test that requires sudachi_projection + """ + return unittest.skipUnless(is_sudachi_projection_available(), "test requires sudachi which supports projection")( + test_case + ) + + +def require_jumanpp(test_case): + """ + Decorator marking a test that requires jumanpp + """ + return unittest.skipUnless(is_jumanpp_available(), "test requires jumanpp")(test_case) + + +def require_cython(test_case): + """ + Decorator marking a test that requires jumanpp + """ + return unittest.skipUnless(is_cython_available(), "test requires cython")(test_case) + + +def require_tiktoken(test_case): + """ + Decorator marking a test that requires TikToken. These tests are skipped when TikToken isn't installed. + """ + return unittest.skipUnless(is_tiktoken_available(), "test requires TikToken")(test_case) + + +def require_speech(test_case): + """ + Decorator marking a test that requires speech. These tests are skipped when speech isn't available. + """ + return unittest.skipUnless(is_speech_available(), "test requires torchaudio")(test_case) + + +def require_openai(test_case): + """ + Decorator marking a test that requires openai + """ + return unittest.skipUnless(is_openai_available(), "test requires openai")(test_case) + + +def require_mistral_common(test_case): + """ + Decorator marking a test that requires mistral-common. These tests are skipped when mistral-common isn't available. + """ + return unittest.skipUnless(is_mistral_common_available(), "test requires mistral-common")(test_case) + + +def get_gpu_count(): + """ + Return the number of available gpus (regardless of whether torch, tf or jax is used) + """ + if is_torch_available(): + import torch + + return torch.cuda.device_count() + elif is_tf_available(): + import tensorflow as tf + + return len(tf.config.list_physical_devices("GPU")) + elif is_flax_available(): + import jax + + return jax.device_count() + else: + return 0 + + +def get_tests_dir(append_path=None): + """ + Args: + append_path: optional path to append to the tests dir path + + Return: + The full path to the `tests` dir, so that the tests can be invoked from anywhere. Optionally `append_path` is + joined after the `tests` dir the former is provided. + + """ + # this function caller's __file__ + caller__file__ = inspect.stack()[1][1] + tests_dir = os.path.abspath(os.path.dirname(caller__file__)) + + while not tests_dir.endswith("tests"): + tests_dir = os.path.dirname(tests_dir) + + if append_path: + return os.path.join(tests_dir, append_path) + else: + return tests_dir + + +def get_steps_per_epoch(trainer: Trainer) -> int: + training_args = trainer.args + train_dataloader = trainer.get_train_dataloader() + + initial_training_values = trainer.set_initial_training_values( + args=training_args, + dataloader=train_dataloader, + total_train_batch_size=training_args.per_device_train_batch_size, + ) + steps_per_epoch = initial_training_values[1] + + return steps_per_epoch + + +def evaluate_side_effect_factory( + side_effect_values: list[dict[str, float]], +) -> Generator[dict[str, float], None, None]: + """ + Function that returns side effects for the _evaluate method. + Used when we're unsure of exactly how many times _evaluate will be called. + """ + yield from side_effect_values + + while True: + yield side_effect_values[-1] + + +# +# Helper functions for dealing with testing text outputs +# The original code came from: +# https://github.com/fastai/fastai/blob/master/tests/utils/text.py + + +# When any function contains print() calls that get overwritten, like progress bars, +# a special care needs to be applied, since under pytest -s captured output (capsys +# or contextlib.redirect_stdout) contains any temporary printed strings, followed by +# \r's. This helper function ensures that the buffer will contain the same output +# with and without -s in pytest, by turning: +# foo bar\r tar mar\r final message +# into: +# final message +# it can handle a single string or a multiline buffer +def apply_print_resets(buf): + return re.sub(r"^.*\r", "", buf, 0, re.MULTILINE) + + +def assert_screenout(out, what): + out_pr = apply_print_resets(out).lower() + match_str = out_pr.find(what.lower()) + assert match_str != -1, f"expecting to find {what} in output: f{out_pr}" + + +def set_config_for_less_flaky_test(config): + target_attrs = [ + "rms_norm_eps", + "layer_norm_eps", + "norm_eps", + "norm_epsilon", + "layer_norm_epsilon", + "batch_norm_eps", + ] + for target_attr in target_attrs: + setattr(config, target_attr, 1.0) + + # norm layers (layer/group norm, etc.) could cause flaky tests when the tensors have very small variance. + # (We don't need the original epsilon values to check eager/sdpa matches) + attrs = ["text_config", "vision_config", "text_encoder", "audio_encoder", "decoder"] + for attr in attrs: + if hasattr(config, attr): + for target_attr in target_attrs: + setattr(getattr(config, attr), target_attr, 1.0) + + +def set_model_for_less_flaky_test(model): + # Another way to make sure norm layers have desired epsilon. (Some models don't set it from its config.) + target_names = ( + "LayerNorm", + "GroupNorm", + "BatchNorm", + "RMSNorm", + "BatchNorm2d", + "BatchNorm1d", + "BitGroupNormActivation", + "WeightStandardizedConv2d", + ) + target_attrs = ["eps", "epsilon", "variance_epsilon"] + if is_torch_available() and isinstance(model, torch.nn.Module): + for module in model.modules(): + if type(module).__name__.endswith(target_names): + for attr in target_attrs: + if hasattr(module, attr): + setattr(module, attr, 1.0) + + +class CaptureStd: + """ + Context manager to capture: + + - stdout: replay it, clean it up and make it available via `obj.out` + - stderr: replay it and make it available via `obj.err` + + Args: + out (`bool`, *optional*, defaults to `True`): Whether to capture stdout or not. + err (`bool`, *optional*, defaults to `True`): Whether to capture stderr or not. + replay (`bool`, *optional*, defaults to `True`): Whether to replay or not. + By default each captured stream gets replayed back on context's exit, so that one can see what the test was + doing. If this is a not wanted behavior and the captured data shouldn't be replayed, pass `replay=False` to + disable this feature. + + Examples: + + ```python + # to capture stdout only with auto-replay + with CaptureStdout() as cs: + print("Secret message") + assert "message" in cs.out + + # to capture stderr only with auto-replay + import sys + + with CaptureStderr() as cs: + print("Warning: ", file=sys.stderr) + assert "Warning" in cs.err + + # to capture both streams with auto-replay + with CaptureStd() as cs: + print("Secret message") + print("Warning: ", file=sys.stderr) + assert "message" in cs.out + assert "Warning" in cs.err + + # to capture just one of the streams, and not the other, with auto-replay + with CaptureStd(err=False) as cs: + print("Secret message") + assert "message" in cs.out + # but best use the stream-specific subclasses + + # to capture without auto-replay + with CaptureStd(replay=False) as cs: + print("Secret message") + assert "message" in cs.out + ```""" + + def __init__(self, out=True, err=True, replay=True): + self.replay = replay + + if out: + self.out_buf = StringIO() + self.out = "error: CaptureStd context is unfinished yet, called too early" + else: + self.out_buf = None + self.out = "not capturing stdout" + + if err: + self.err_buf = StringIO() + self.err = "error: CaptureStd context is unfinished yet, called too early" + else: + self.err_buf = None + self.err = "not capturing stderr" + + def __enter__(self): + if self.out_buf: + self.out_old = sys.stdout + sys.stdout = self.out_buf + + if self.err_buf: + self.err_old = sys.stderr + sys.stderr = self.err_buf + + return self + + def __exit__(self, *exc): + if self.out_buf: + sys.stdout = self.out_old + captured = self.out_buf.getvalue() + if self.replay: + sys.stdout.write(captured) + self.out = apply_print_resets(captured) + + if self.err_buf: + sys.stderr = self.err_old + captured = self.err_buf.getvalue() + if self.replay: + sys.stderr.write(captured) + self.err = captured + + def __repr__(self): + msg = "" + if self.out_buf: + msg += f"stdout: {self.out}\n" + if self.err_buf: + msg += f"stderr: {self.err}\n" + return msg + + +# in tests it's the best to capture only the stream that's wanted, otherwise +# it's easy to miss things, so unless you need to capture both streams, use the +# subclasses below (less typing). Or alternatively, configure `CaptureStd` to +# disable the stream you don't need to test. + + +class CaptureStdout(CaptureStd): + """Same as CaptureStd but captures only stdout""" + + def __init__(self, replay=True): + super().__init__(err=False, replay=replay) + + +class CaptureStderr(CaptureStd): + """Same as CaptureStd but captures only stderr""" + + def __init__(self, replay=True): + super().__init__(out=False, replay=replay) + + +class CaptureLogger: + """ + Context manager to capture `logging` streams + + Args: + logger: 'logging` logger object + + Returns: + The captured output is available via `self.out` + + Example: + + ```python + >>> from transformers import logging + >>> from transformers.testing_utils import CaptureLogger + + >>> msg = "Testing 1, 2, 3" + >>> logging.set_verbosity_info() + >>> logger = logging.get_logger("transformers.models.bart.tokenization_bart") + >>> with CaptureLogger(logger) as cl: + ... logger.info(msg) + >>> assert cl.out, msg + "\n" + ``` + """ + + def __init__(self, logger): + self.logger = logger + self.io = StringIO() + self.sh = logging.StreamHandler(self.io) + self.out = "" + + def __enter__(self): + self.logger.addHandler(self.sh) + return self + + def __exit__(self, *exc): + self.logger.removeHandler(self.sh) + self.out = self.io.getvalue() + + def __repr__(self): + return f"captured: {self.out}\n" + + +@contextlib.contextmanager +def LoggingLevel(level): + """ + This is a context manager to temporarily change transformers modules logging level to the desired value and have it + restored to the original setting at the end of the scope. + + Example: + + ```python + with LoggingLevel(logging.INFO): + AutoModel.from_pretrained("openai-community/gpt2") # calls logger.info() several times + ``` + """ + orig_level = transformers_logging.get_verbosity() + try: + transformers_logging.set_verbosity(level) + yield + finally: + transformers_logging.set_verbosity(orig_level) + + +class TemporaryHubRepo: + """Create a temporary Hub repository and return its `RepoUrl` object. This is similar to + `tempfile.TemporaryDirectory` and can be used as a context manager. For example: + + with TemporaryHubRepo(token=self._token) as temp_repo: + ... + + Upon exiting the context, the repository and everything contained in it are removed. + + Example: + + ```python + with TemporaryHubRepo(token=self._token) as temp_repo: + model.push_to_hub(tmp_repo.repo_id, token=self._token) + ``` + """ + + def __init__(self, namespace: Optional[str] = None, token: Optional[str] = None) -> None: + self.token = token + with tempfile.TemporaryDirectory() as tmp_dir: + repo_id = Path(tmp_dir).name + if namespace is not None: + repo_id = f"{namespace}/{repo_id}" + self.repo_url = huggingface_hub.create_repo(repo_id, token=self.token) + + def __enter__(self): + return self.repo_url + + def __exit__(self, exc, value, tb): + delete_repo(repo_id=self.repo_url.repo_id, token=self.token, missing_ok=True) + + +@contextlib.contextmanager +# adapted from https://stackoverflow.com/a/64789046/9201239 +def ExtendSysPath(path: Union[str, os.PathLike]) -> Iterator[None]: + """ + Temporary add given path to `sys.path`. + + Usage : + + ```python + with ExtendSysPath("/path/to/dir"): + mymodule = importlib.import_module("mymodule") + ``` + """ + + path = os.fspath(path) + try: + sys.path.insert(0, path) + yield + finally: + sys.path.remove(path) + + +class TestCasePlus(unittest.TestCase): + """ + This class extends *unittest.TestCase* with additional features. + + Feature 1: A set of fully resolved important file and dir path accessors. + + In tests often we need to know where things are relative to the current test file, and it's not trivial since the + test could be invoked from more than one directory or could reside in sub-directories with different depths. This + class solves this problem by sorting out all the basic paths and provides easy accessors to them: + + - `pathlib` objects (all fully resolved): + + - `test_file_path` - the current test file path (=`__file__`) + - `test_file_dir` - the directory containing the current test file + - `tests_dir` - the directory of the `tests` test suite + - `examples_dir` - the directory of the `examples` test suite + - `repo_root_dir` - the directory of the repository + - `src_dir` - the directory of `src` (i.e. where the `transformers` sub-dir resides) + + - stringified paths---same as above but these return paths as strings, rather than `pathlib` objects: + + - `test_file_path_str` + - `test_file_dir_str` + - `tests_dir_str` + - `examples_dir_str` + - `repo_root_dir_str` + - `src_dir_str` + + Feature 2: Flexible auto-removable temporary dirs which are guaranteed to get removed at the end of test. + + 1. Create a unique temporary dir: + + ```python + def test_whatever(self): + tmp_dir = self.get_auto_remove_tmp_dir() + ``` + + `tmp_dir` will contain the path to the created temporary dir. It will be automatically removed at the end of the + test. + + + 2. Create a temporary dir of my choice, ensure it's empty before the test starts and don't + empty it after the test. + + ```python + def test_whatever(self): + tmp_dir = self.get_auto_remove_tmp_dir("./xxx") + ``` + + This is useful for debug when you want to monitor a specific directory and want to make sure the previous tests + didn't leave any data in there. + + 3. You can override the first two options by directly overriding the `before` and `after` args, leading to the + following behavior: + + `before=True`: the temporary dir will always be cleared at the beginning of the test. + + `before=False`: if the temporary dir already existed, any existing files will remain there. + + `after=True`: the temporary dir will always be deleted at the end of the test. + + `after=False`: the temporary dir will always be left intact at the end of the test. + + Note 1: In order to run the equivalent of `rm -r` safely, only subdirs of the project repository checkout are + allowed if an explicit `tmp_dir` is used, so that by mistake no `/tmp` or similar important part of the filesystem + will get nuked. i.e. please always pass paths that start with `./` + + Note 2: Each test can register multiple temporary dirs and they all will get auto-removed, unless requested + otherwise. + + Feature 3: Get a copy of the `os.environ` object that sets up `PYTHONPATH` specific to the current test suite. This + is useful for invoking external programs from the test suite - e.g. distributed training. + + + ```python + def test_whatever(self): + env = self.get_env() + ```""" + + def setUp(self): + # get_auto_remove_tmp_dir feature: + self.teardown_tmp_dirs = [] + + # figure out the resolved paths for repo_root, tests, examples, etc. + self._test_file_path = inspect.getfile(self.__class__) + path = Path(self._test_file_path).resolve() + self._test_file_dir = path.parents[0] + for up in [1, 2, 3]: + tmp_dir = path.parents[up] + if (tmp_dir / "src").is_dir() and (tmp_dir / "tests").is_dir(): + break + if tmp_dir: + self._repo_root_dir = tmp_dir + else: + raise ValueError(f"can't figure out the root of the repo from {self._test_file_path}") + self._tests_dir = self._repo_root_dir / "tests" + self._examples_dir = self._repo_root_dir / "examples" + self._src_dir = self._repo_root_dir / "src" + + @property + def test_file_path(self): + return self._test_file_path + + @property + def test_file_path_str(self): + return str(self._test_file_path) + + @property + def test_file_dir(self): + return self._test_file_dir + + @property + def test_file_dir_str(self): + return str(self._test_file_dir) + + @property + def tests_dir(self): + return self._tests_dir + + @property + def tests_dir_str(self): + return str(self._tests_dir) + + @property + def examples_dir(self): + return self._examples_dir + + @property + def examples_dir_str(self): + return str(self._examples_dir) + + @property + def repo_root_dir(self): + return self._repo_root_dir + + @property + def repo_root_dir_str(self): + return str(self._repo_root_dir) + + @property + def src_dir(self): + return self._src_dir + + @property + def src_dir_str(self): + return str(self._src_dir) + + def get_env(self): + """ + Return a copy of the `os.environ` object that sets up `PYTHONPATH` correctly, depending on the test suite it's + invoked from. This is useful for invoking external programs from the test suite - e.g. distributed training. + + It always inserts `./src` first, then `./tests` or `./examples` depending on the test suite type and finally + the preset `PYTHONPATH` if any (all full resolved paths). + + """ + env = os.environ.copy() + paths = [self.repo_root_dir_str, self.src_dir_str] + if "/examples" in self.test_file_dir_str: + paths.append(self.examples_dir_str) + else: + paths.append(self.tests_dir_str) + paths.append(env.get("PYTHONPATH", "")) + + env["PYTHONPATH"] = ":".join(paths) + return env + + def get_auto_remove_tmp_dir(self, tmp_dir=None, before=None, after=None): + """ + Args: + tmp_dir (`string`, *optional*): + if `None`: + + - a unique temporary path will be created + - sets `before=True` if `before` is `None` + - sets `after=True` if `after` is `None` + else: + + - `tmp_dir` will be created + - sets `before=True` if `before` is `None` + - sets `after=False` if `after` is `None` + before (`bool`, *optional*): + If `True` and the `tmp_dir` already exists, make sure to empty it right away if `False` and the + `tmp_dir` already exists, any existing files will remain there. + after (`bool`, *optional*): + If `True`, delete the `tmp_dir` at the end of the test if `False`, leave the `tmp_dir` and its contents + intact at the end of the test. + + Returns: + tmp_dir(`string`): either the same value as passed via *tmp_dir* or the path to the auto-selected tmp dir + """ + if tmp_dir is not None: + # defining the most likely desired behavior for when a custom path is provided. + # this most likely indicates the debug mode where we want an easily locatable dir that: + # 1. gets cleared out before the test (if it already exists) + # 2. is left intact after the test + if before is None: + before = True + if after is None: + after = False + + # using provided path + path = Path(tmp_dir).resolve() + + # to avoid nuking parts of the filesystem, only relative paths are allowed + if not tmp_dir.startswith("./"): + raise ValueError( + f"`tmp_dir` can only be a relative path, i.e. `./some/path`, but received `{tmp_dir}`" + ) + + # ensure the dir is empty to start with + if before is True and path.exists(): + shutil.rmtree(tmp_dir, ignore_errors=True) + + path.mkdir(parents=True, exist_ok=True) + + else: + # defining the most likely desired behavior for when a unique tmp path is auto generated + # (not a debug mode), here we require a unique tmp dir that: + # 1. is empty before the test (it will be empty in this situation anyway) + # 2. gets fully removed after the test + if before is None: + before = True + if after is None: + after = True + + # using unique tmp dir (always empty, regardless of `before`) + tmp_dir = tempfile.mkdtemp() + + if after is True: + # register for deletion + self.teardown_tmp_dirs.append(tmp_dir) + + return tmp_dir + + def python_one_liner_max_rss(self, one_liner_str): + """ + Runs the passed python one liner (just the code) and returns how much max cpu memory was used to run the + program. + + Args: + one_liner_str (`string`): + a python one liner code that gets passed to `python -c` + + Returns: + max cpu memory bytes used to run the program. This value is likely to vary slightly from run to run. + + Requirements: + this helper needs `/usr/bin/time` to be installed (`apt install time`) + + Example: + + ``` + one_liner_str = 'from transformers import AutoModel; AutoModel.from_pretrained("google-t5/t5-large")' + max_rss = self.python_one_liner_max_rss(one_liner_str) + ``` + """ + + if not cmd_exists("/usr/bin/time"): + raise ValueError("/usr/bin/time is required, install with `apt install time`") + + cmd = shlex.split(f"/usr/bin/time -f %M python -c '{one_liner_str}'") + with CaptureStd() as cs: + execute_subprocess_async(cmd, env=self.get_env()) + # returned data is in KB so convert to bytes + max_rss = int(cs.err.split("\n")[-2].replace("stderr: ", "")) * 1024 + return max_rss + + def tearDown(self): + # get_auto_remove_tmp_dir feature: remove registered temp dirs + for path in self.teardown_tmp_dirs: + shutil.rmtree(path, ignore_errors=True) + self.teardown_tmp_dirs = [] + if is_accelerate_available(): + AcceleratorState._reset_state() + PartialState._reset_state() + + # delete all the env variables having `ACCELERATE` in them + for k in list(os.environ.keys()): + if "ACCELERATE" in k: + del os.environ[k] + + +def mockenv(**kwargs): + """ + this is a convenience wrapper, that allows this :: + + @mockenv(RUN_SLOW=True, USE_TF=False) def test_something(): + run_slow = os.getenv("RUN_SLOW", False) use_tf = os.getenv("USE_TF", False) + + """ + return mock.patch.dict(os.environ, kwargs) + + +# from https://stackoverflow.com/a/34333710/9201239 +@contextlib.contextmanager +def mockenv_context(*remove, **update): + """ + Temporarily updates the `os.environ` dictionary in-place. Similar to mockenv + + The `os.environ` dictionary is updated in-place so that the modification is sure to work in all situations. + + Args: + remove: Environment variables to remove. + update: Dictionary of environment variables and values to add/update. + """ + env = os.environ + update = update or {} + remove = remove or [] + + # List of environment variables being updated or removed. + stomped = (set(update.keys()) | set(remove)) & set(env.keys()) + # Environment variables and values to restore on exit. + update_after = {k: env[k] for k in stomped} + # Environment variables and values to remove on exit. + remove_after = frozenset(k for k in update if k not in env) + + try: + env.update(update) + [env.pop(k, None) for k in remove] + yield + finally: + env.update(update_after) + [env.pop(k) for k in remove_after] + + +# --- pytest conf functions --- # + +# to avoid multiple invocation from tests/conftest.py and examples/conftest.py - make sure it's called only once +pytest_opt_registered = {} + + +def pytest_addoption_shared(parser): + """ + This function is to be called from `conftest.py` via `pytest_addoption` wrapper that has to be defined there. + + It allows loading both `conftest.py` files at once without causing a failure due to adding the same `pytest` + option. + + """ + option = "--make-reports" + if option not in pytest_opt_registered: + parser.addoption( + option, + action="store", + default=False, + help="generate report files. The value of this option is used as a prefix to report names", + ) + pytest_opt_registered[option] = 1 + + +def pytest_terminal_summary_main(tr, id): + """ + Generate multiple reports at the end of test suite run - each report goes into a dedicated file in the current + directory. The report files are prefixed with the test suite name. + + This function emulates --duration and -rA pytest arguments. + + This function is to be called from `conftest.py` via `pytest_terminal_summary` wrapper that has to be defined + there. + + Args: + - tr: `terminalreporter` passed from `conftest.py` + - id: unique id like `tests` or `examples` that will be incorporated into the final reports filenames - this is + needed as some jobs have multiple runs of pytest, so we can't have them overwrite each other. + + NB: this functions taps into a private _pytest API and while unlikely, it could break should pytest do internal + changes - also it calls default internal methods of terminalreporter which can be hijacked by various `pytest-` + plugins and interfere. + + """ + from _pytest.config import create_terminal_writer + + if not len(id): + id = "tests" + + config = tr.config + orig_writer = config.get_terminal_writer() + orig_tbstyle = config.option.tbstyle + orig_reportchars = tr.reportchars + + dir = f"reports/{id}" + Path(dir).mkdir(parents=True, exist_ok=True) + report_files = { + k: f"{dir}/{k}.txt" + for k in [ + "durations", + "errors", + "failures_long", + "failures_short", + "failures_line", + "passes", + "stats", + "summary_short", + "warnings", + ] + } + + # custom durations report + # note: there is no need to call pytest --durations=XX to get this separate report + # adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/runner.py#L66 + dlist = [] + for replist in tr.stats.values(): + for rep in replist: + if hasattr(rep, "duration"): + dlist.append(rep) + if dlist: + dlist.sort(key=lambda x: x.duration, reverse=True) + with open(report_files["durations"], "w") as f: + durations_min = 0.05 # sec + f.write("slowest durations\n") + for i, rep in enumerate(dlist): + if rep.duration < durations_min: + f.write(f"{len(dlist) - i} durations < {durations_min} secs were omitted") + break + f.write(f"{rep.duration:02.2f}s {rep.when:<8} {rep.nodeid}\n") + + def summary_failures_short(tr): + # expecting that the reports were --tb=long (default) so we chop them off here to the last frame + reports = tr.getreports("failed") + if not reports: + return + tr.write_sep("=", "FAILURES SHORT STACK") + for rep in reports: + msg = tr._getfailureheadline(rep) + tr.write_sep("_", msg, red=True, bold=True) + # chop off the optional leading extra frames, leaving only the last one + longrepr = re.sub(r".*_ _ _ (_ ){10,}_ _ ", "", rep.longreprtext, 0, re.MULTILINE | re.DOTALL) + tr._tw.line(longrepr) + # note: not printing out any rep.sections to keep the report short + + # use ready-made report funcs, we are just hijacking the filehandle to log to a dedicated file each + # adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/terminal.py#L814 + # note: some pytest plugins may interfere by hijacking the default `terminalreporter` (e.g. + # pytest-instafail does that) + + # report failures with line/short/long styles + config.option.tbstyle = "auto" # full tb + with open(report_files["failures_long"], "w") as f: + tr._tw = create_terminal_writer(config, f) + tr.summary_failures() + + # config.option.tbstyle = "short" # short tb + with open(report_files["failures_short"], "w") as f: + tr._tw = create_terminal_writer(config, f) + summary_failures_short(tr) + + config.option.tbstyle = "line" # one line per error + with open(report_files["failures_line"], "w") as f: + tr._tw = create_terminal_writer(config, f) + tr.summary_failures() + + with open(report_files["errors"], "w") as f: + tr._tw = create_terminal_writer(config, f) + tr.summary_errors() + + with open(report_files["warnings"], "w") as f: + tr._tw = create_terminal_writer(config, f) + tr.summary_warnings() # normal warnings + tr.summary_warnings() # final warnings + + tr.reportchars = "wPpsxXEf" # emulate -rA (used in summary_passes() and short_test_summary()) + + # Skip the `passes` report, as it starts to take more than 5 minutes, and sometimes it timeouts on CircleCI if it + # takes > 10 minutes (as this part doesn't generate any output on the terminal). + # (also, it seems there is no useful information in this report, and we rarely need to read it) + # with open(report_files["passes"], "w") as f: + # tr._tw = create_terminal_writer(config, f) + # tr.summary_passes() + + with open(report_files["summary_short"], "w") as f: + tr._tw = create_terminal_writer(config, f) + tr.short_test_summary() + + with open(report_files["stats"], "w") as f: + tr._tw = create_terminal_writer(config, f) + tr.summary_stats() + + # restore: + tr._tw = orig_writer + tr.reportchars = orig_reportchars + config.option.tbstyle = orig_tbstyle + + +# --- distributed testing functions --- # + +# adapted from https://stackoverflow.com/a/59041913/9201239 +import asyncio # noqa + + +class _RunOutput: + def __init__(self, returncode, stdout, stderr): + self.returncode = returncode + self.stdout = stdout + self.stderr = stderr + + +async def _read_stream(stream, callback): + while True: + line = await stream.readline() + if line: + callback(line) + else: + break + + +async def _stream_subprocess(cmd, env=None, stdin=None, timeout=None, quiet=False, echo=False) -> _RunOutput: + if echo: + print("\nRunning: ", " ".join(cmd)) + + p = await asyncio.create_subprocess_exec( + cmd[0], + *cmd[1:], + stdin=stdin, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + env=env, + ) + + # note: there is a warning for a possible deadlock when using `wait` with huge amounts of data in the pipe + # https://docs.python.org/3/library/asyncio-subprocess.html#asyncio.asyncio.subprocess.Process.wait + # + # If it starts hanging, will need to switch to the following code. The problem is that no data + # will be seen until it's done and if it hangs for example there will be no debug info. + # out, err = await p.communicate() + # return _RunOutput(p.returncode, out, err) + + out = [] + err = [] + + def tee(line, sink, pipe, label=""): + line = line.decode("utf-8").rstrip() + sink.append(line) + if not quiet: + print(label, line, file=pipe) + + # XXX: the timeout doesn't seem to make any difference here + await asyncio.wait( + [ + asyncio.create_task(_read_stream(p.stdout, lambda l: tee(l, out, sys.stdout, label="stdout:"))), + asyncio.create_task(_read_stream(p.stderr, lambda l: tee(l, err, sys.stderr, label="stderr:"))), + ], + timeout=timeout, + ) + return _RunOutput(await p.wait(), out, err) + + +def execute_subprocess_async(cmd, env=None, stdin=None, timeout=180, quiet=False, echo=True) -> _RunOutput: + loop = asyncio.get_event_loop() + result = loop.run_until_complete( + _stream_subprocess(cmd, env=env, stdin=stdin, timeout=timeout, quiet=quiet, echo=echo) + ) + + cmd_str = " ".join(cmd) + if result.returncode > 0: + stderr = "\n".join(result.stderr) + raise RuntimeError( + f"'{cmd_str}' failed with returncode {result.returncode}\n\n" + f"The combined stderr from workers follows:\n{stderr}" + ) + + # check that the subprocess actually did run and produced some output, should the test rely on + # the remote side to do the testing + if not result.stdout and not result.stderr: + raise RuntimeError(f"'{cmd_str}' produced no output.") + + return result + + +def pytest_xdist_worker_id(): + """ + Returns an int value of worker's numerical id under `pytest-xdist`'s concurrent workers `pytest -n N` regime, or 0 + if `-n 1` or `pytest-xdist` isn't being used. + """ + worker = os.environ.get("PYTEST_XDIST_WORKER", "gw0") + worker = re.sub(r"^gw", "", worker, 0, re.MULTILINE) + return int(worker) + + +def get_torch_dist_unique_port(): + """ + Returns a port number that can be fed to `torch.distributed.launch`'s `--master_port` argument. + + Under `pytest-xdist` it adds a delta number based on a worker id so that concurrent tests don't try to use the same + port at once. + """ + port = 29500 + uniq_delta = pytest_xdist_worker_id() + return port + uniq_delta + + +def nested_simplify(obj, decimals=3): + """ + Simplifies an object by rounding float numbers, and downcasting tensors/numpy arrays to get simple equality test + within tests. + """ + import numpy as np + + if isinstance(obj, list): + return [nested_simplify(item, decimals) for item in obj] + if isinstance(obj, tuple): + return tuple(nested_simplify(item, decimals) for item in obj) + elif isinstance(obj, np.ndarray): + return nested_simplify(obj.tolist()) + elif isinstance(obj, Mapping): + return {nested_simplify(k, decimals): nested_simplify(v, decimals) for k, v in obj.items()} + elif isinstance(obj, (str, int, np.int64)) or obj is None: + return obj + elif is_torch_available() and isinstance(obj, torch.Tensor): + return nested_simplify(obj.tolist(), decimals) + elif is_tf_available() and tf.is_tensor(obj): + return nested_simplify(obj.numpy().tolist()) + elif isinstance(obj, float): + return round(obj, decimals) + elif isinstance(obj, (np.int32, np.float32, np.float16)): + return nested_simplify(obj.item(), decimals) + else: + raise Exception(f"Not supported: {type(obj)}") + + +def check_json_file_has_correct_format(file_path): + with open(file_path) as f: + lines = f.readlines() + if len(lines) == 1: + # length can only be 1 if dict is empty + assert lines[0] == "{}" + else: + # otherwise make sure json has correct format (at least 3 lines) + assert len(lines) >= 3 + # each key one line, ident should be 2, min length is 3 + assert lines[0].strip() == "{" + for line in lines[1:-1]: + left_indent = len(lines[1]) - len(lines[1].lstrip()) + assert left_indent == 2 + assert lines[-1].strip() == "}" + + +def to_2tuple(x): + if isinstance(x, collections.abc.Iterable): + return x + return (x, x) + + +# These utils relate to ensuring the right error message is received when running scripts +class SubprocessCallException(Exception): + pass + + +def run_command(command: list[str], return_stdout=False): + """ + Runs `command` with `subprocess.check_output` and will potentially return the `stdout`. Will also properly capture + if an error occurred while running `command` + """ + try: + output = subprocess.check_output(command, stderr=subprocess.STDOUT) + if return_stdout: + if hasattr(output, "decode"): + output = output.decode("utf-8") + return output + except subprocess.CalledProcessError as e: + raise SubprocessCallException( + f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}" + ) from e + + +class RequestCounter: + """ + Helper class that will count all requests made online. + + Might not be robust if urllib3 changes its logging format but should be good enough for us. + + Usage: + ```py + with RequestCounter() as counter: + _ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert") + assert counter["GET"] == 0 + assert counter["HEAD"] == 1 + assert counter.total_calls == 1 + ``` + """ + + def __enter__(self): + self._counter = defaultdict(int) + self._thread_id = threading.get_ident() + self._extra_info = [] + + def patched_with_thread_info(func): + def wrap(*args, **kwargs): + self._extra_info.append(threading.get_ident()) + return func(*args, **kwargs) + + return wrap + + self.patcher = patch.object( + urllib3.connectionpool.log, "debug", side_effect=patched_with_thread_info(urllib3.connectionpool.log.debug) + ) + self.mock = self.patcher.start() + return self + + def __exit__(self, *args, **kwargs) -> None: + assert len(self.mock.call_args_list) == len(self._extra_info) + for thread_id, call in zip(self._extra_info, self.mock.call_args_list): + if thread_id != self._thread_id: + continue + # code 307: the URL being requested by the user has moved to a temporary location + if call.args[-2] == 307: + continue + log = call.args[0] % call.args[1:] + for method in ("HEAD", "GET", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH"): + if method in log: + self._counter[method] += 1 + break + self.patcher.stop() + + def __getitem__(self, key: str) -> int: + return self._counter[key] + + @property + def total_calls(self) -> int: + return sum(self._counter.values()) + + +def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, description: Optional[str] = None): + """ + To decorate flaky tests. They will be retried on failures. + + Please note that our push tests use `pytest-rerunfailures`, which prompts the CI to rerun certain types of + failed tests. More specifically, if the test exception contains any substring in `FLAKY_TEST_FAILURE_PATTERNS` + (in `.circleci/create_circleci_config.py`), it will be rerun. If you find a recurrent pattern of failures, + expand `FLAKY_TEST_FAILURE_PATTERNS` in our CI configuration instead of using `is_flaky`. + + Args: + max_attempts (`int`, *optional*, defaults to 5): + The maximum number of attempts to retry the flaky test. + wait_before_retry (`float`, *optional*): + If provided, will wait that number of seconds before retrying the test. + description (`str`, *optional*): + A string to describe the situation (what / where / why is flaky, link to GH issue/PR comments, errors, + etc.) + """ + + def decorator(test_func_ref): + @functools.wraps(test_func_ref) + def wrapper(*args, **kwargs): + retry_count = 1 + + while retry_count < max_attempts: + try: + return test_func_ref(*args, **kwargs) + + except Exception as err: + logger.error(f"Test failed with {err} at try {retry_count}/{max_attempts}.") + if wait_before_retry is not None: + time.sleep(wait_before_retry) + retry_count += 1 + + return test_func_ref(*args, **kwargs) + + return unittest.skipUnless(_run_flaky_tests, "test is flaky")(wrapper) + + return decorator + + +def hub_retry(max_attempts: int = 5, wait_before_retry: Optional[float] = 2): + """ + To decorate tests that download from the Hub. They can fail due to a + variety of network issues such as timeouts, connection resets, etc. + + Args: + max_attempts (`int`, *optional*, defaults to 5): + The maximum number of attempts to retry the flaky test. + wait_before_retry (`float`, *optional*, defaults to 2): + If provided, will wait that number of seconds before retrying the test. + """ + + def decorator(test_func_ref): + @functools.wraps(test_func_ref) + def wrapper(*args, **kwargs): + retry_count = 1 + + while retry_count < max_attempts: + try: + return test_func_ref(*args, **kwargs) + # We catch all exceptions related to network issues from requests + except ( + requests.exceptions.ConnectionError, + requests.exceptions.Timeout, + requests.exceptions.ReadTimeout, + requests.exceptions.HTTPError, + requests.exceptions.RequestException, + ) as err: + logger.error( + f"Test failed with {err} at try {retry_count}/{max_attempts} as it couldn't connect to the specified Hub repository." + ) + if wait_before_retry is not None: + time.sleep(wait_before_retry) + retry_count += 1 + + return test_func_ref(*args, **kwargs) + + return wrapper + + return decorator + + +def run_first(test_case): + """ + Decorator marking a test with order(1). When pytest-order plugin is installed, tests marked with this decorator + are guaranteed to run first. + + This is especially useful in some test settings like on a Gaudi instance where a Gaudi device can only be used by a + single process at a time. So we make sure all tests that run in a subprocess are launched first, to avoid device + allocation conflicts. + """ + import pytest + + return pytest.mark.order(1)(test_case) + + +def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None): + """ + To run a test in a subprocess. In particular, this can avoid (GPU) memory issue. + + Args: + test_case (`unittest.TestCase`): + The test that will run `target_func`. + target_func (`Callable`): + The function implementing the actual testing logic. + inputs (`dict`, *optional*, defaults to `None`): + The inputs that will be passed to `target_func` through an (input) queue. + timeout (`int`, *optional*, defaults to `None`): + The timeout (in seconds) that will be passed to the input and output queues. If not specified, the env. + variable `PYTEST_TIMEOUT` will be checked. If still `None`, its value will be set to `600`. + """ + if timeout is None: + timeout = int(os.environ.get("PYTEST_TIMEOUT", "600")) + + start_methohd = "spawn" + ctx = multiprocessing.get_context(start_methohd) + + input_queue = ctx.Queue(1) + output_queue = ctx.JoinableQueue(1) + + # We can't send `unittest.TestCase` to the child, otherwise we get issues regarding pickle. + input_queue.put(inputs, timeout=timeout) + + process = ctx.Process(target=target_func, args=(input_queue, output_queue, timeout)) + process.start() + # Kill the child process if we can't get outputs from it in time: otherwise, the hanging subprocess prevents + # the test to exit properly. + try: + results = output_queue.get(timeout=timeout) + output_queue.task_done() + except Exception as e: + process.terminate() + test_case.fail(e) + process.join(timeout=timeout) + + if results["error"] is not None: + test_case.fail(f"{results['error']}") + + +def run_test_using_subprocess(func): + """ + To decorate a test to run in a subprocess using the `subprocess` module. This could avoid potential GPU memory + issues (GPU OOM or a test that causes many subsequential failing with `CUDA error: device-side assert triggered`). + """ + import pytest + + @functools.wraps(func) + def wrapper(*args, **kwargs): + if os.getenv("_INSIDE_SUB_PROCESS", None) == "1": + func(*args, **kwargs) + else: + test = " ".join(os.environ.get("PYTEST_CURRENT_TEST").split(" ")[:-1]) + try: + env = copy.deepcopy(os.environ) + env["_INSIDE_SUB_PROCESS"] = "1" + # This prevents the entries in `short test summary info` given by the subprocess being truncated. so the + # full information can be passed to the parent pytest process. + # See: https://docs.pytest.org/en/stable/explanation/ci.html + env["CI"] = "true" + + # If not subclass of `unitTest.TestCase` and `pytestconfig` is used: try to grab and use the arguments + if "pytestconfig" in kwargs: + command = list(kwargs["pytestconfig"].invocation_params.args) + for idx, x in enumerate(command): + if x in kwargs["pytestconfig"].args: + test = test.split("::")[1:] + command[idx] = "::".join([f"{func.__globals__['__file__']}"] + test) + command = [f"{sys.executable}", "-m", "pytest"] + command + command = [x for x in command if x != "--no-summary"] + # Otherwise, simply run the test with no option at all + else: + command = [f"{sys.executable}", "-m", "pytest", f"{test}"] + + subprocess.run(command, env=env, check=True, capture_output=True) + except subprocess.CalledProcessError as e: + exception_message = e.stdout.decode() + lines = exception_message.split("\n") + # Add a first line with more informative information instead of just `= test session starts =`. + # This makes the `short test summary info` section more useful. + if "= test session starts =" in lines[0]: + text = "" + for line in lines[1:]: + if line.startswith("FAILED "): + text = line[len("FAILED ") :] + text = "".join(text.split(" - ")[1:]) + elif line.startswith("=") and line.endswith("=") and " failed in " in line: + break + elif len(text) > 0: + text += f"\n{line}" + text = "(subprocess) " + text + lines = [text] + lines + exception_message = "\n".join(lines) + raise pytest.fail(exception_message, pytrace=False) + + return wrapper + + +""" +The following contains utils to run the documentation tests without having to overwrite any files. + +The `preprocess_string` function adds `# doctest: +IGNORE_RESULT` markers on the fly anywhere a `load_dataset` call is +made as a print would otherwise fail the corresponding line. + +To skip cuda tests, make sure to call `SKIP_CUDA_DOCTEST=1 pytest --doctest-modules +""" + + +def preprocess_string(string, skip_cuda_tests): + """Prepare a docstring or a `.md` file to be run by doctest. + + The argument `string` would be the whole file content if it is a `.md` file. For a python file, it would be one of + its docstring. In each case, it may contain multiple python code examples. If `skip_cuda_tests` is `True` and a + cuda stuff is detective (with a heuristic), this method will return an empty string so no doctest will be run for + `string`. + """ + codeblock_pattern = r"(```(?:python|py)\s*\n\s*>>> )(.*?```)" + codeblocks = re.split(codeblock_pattern, string, flags=re.DOTALL) + is_cuda_found = False + for i, codeblock in enumerate(codeblocks): + if "load_dataset(" in codeblock and "# doctest: +IGNORE_RESULT" not in codeblock: + codeblocks[i] = re.sub(r"(>>> .*load_dataset\(.*)", r"\1 # doctest: +IGNORE_RESULT", codeblock) + if ( + (">>>" in codeblock or "..." in codeblock) + and re.search(r"cuda|to\(0\)|device=0", codeblock) + and skip_cuda_tests + ): + is_cuda_found = True + break + + modified_string = "" + if not is_cuda_found: + modified_string = "".join(codeblocks) + + return modified_string + + +class HfDocTestParser(doctest.DocTestParser): + """ + Overwrites the DocTestParser from doctest to properly parse the codeblocks that are formatted with black. This + means that there are no extra lines at the end of our snippets. The `# doctest: +IGNORE_RESULT` marker is also + added anywhere a `load_dataset` call is made as a print would otherwise fail the corresponding line. + + Tests involving cuda are skipped base on a naive pattern that should be updated if it is not enough. + """ + + # This regular expression is used to find doctest examples in a + # string. It defines three groups: `source` is the source code + # (including leading indentation and prompts); `indent` is the + # indentation of the first (PS1) line of the source code; and + # `want` is the expected output (including leading indentation). + # fmt: off + _EXAMPLE_RE = re.compile(r''' + # Source consists of a PS1 line followed by zero or more PS2 lines. + (?P + (?:^(?P [ ]*) >>> .*) # PS1 line + (?:\n [ ]* \.\.\. .*)*) # PS2 lines + \n? + # Want consists of any non-blank lines that do not start with PS1. + (?P (?:(?![ ]*$) # Not a blank line + (?![ ]*>>>) # Not a line starting with PS1 + # !!!!!!!!!!! HF Specific !!!!!!!!!!! + (?:(?!```).)* # Match any character except '`' until a '```' is found (this is specific to HF because black removes the last line) + # !!!!!!!!!!! HF Specific !!!!!!!!!!! + (?:\n|$) # Match a new line or end of string + )*) + ''', re.MULTILINE | re.VERBOSE + ) + # fmt: on + + # !!!!!!!!!!! HF Specific !!!!!!!!!!! + skip_cuda_tests: bool = bool(os.environ.get("SKIP_CUDA_DOCTEST", "0")) + # !!!!!!!!!!! HF Specific !!!!!!!!!!! + + def parse(self, string, name=""): + """ + Overwrites the `parse` method to incorporate a skip for CUDA tests, and remove logs and dataset prints before + calling `super().parse` + """ + string = preprocess_string(string, self.skip_cuda_tests) + return super().parse(string, name) + + +class HfDoctestModule(Module): + """ + Overwrites the `DoctestModule` of the pytest package to make sure the HFDocTestParser is used when discovering + tests. + """ + + def collect(self) -> Iterable[DoctestItem]: + class MockAwareDocTestFinder(doctest.DocTestFinder): + """A hackish doctest finder that overrides stdlib internals to fix a stdlib bug. + + https://github.com/pytest-dev/pytest/issues/3456 https://bugs.python.org/issue25532 + """ + + def _find_lineno(self, obj, source_lines): + """Doctest code does not take into account `@property`, this + is a hackish way to fix it. https://bugs.python.org/issue17446 + + Wrapped Doctests will need to be unwrapped so the correct line number is returned. This will be + reported upstream. #8796 + """ + if isinstance(obj, property): + obj = getattr(obj, "fget", obj) + + if hasattr(obj, "__wrapped__"): + # Get the main obj in case of it being wrapped + obj = inspect.unwrap(obj) + + # Type ignored because this is a private function. + return super()._find_lineno( # type:ignore[misc] + obj, + source_lines, + ) + + def _find(self, tests, obj, name, module, source_lines, globs, seen) -> None: + if _is_mocked(obj): + return + with _patch_unwrap_mock_aware(): + # Type ignored because this is a private function. + super()._find( # type:ignore[misc] + tests, obj, name, module, source_lines, globs, seen + ) + + if self.path.name == "conftest.py": + module = self.config.pluginmanager._importconftest( + self.path, + self.config.getoption("importmode"), + rootpath=self.config.rootpath, + ) + else: + try: + module = import_path( + self.path, + root=self.config.rootpath, + mode=self.config.getoption("importmode"), + ) + except ImportError: + if self.config.getvalue("doctest_ignore_import_errors"): + skip("unable to import module %r" % self.path) + else: + raise + + # !!!!!!!!!!! HF Specific !!!!!!!!!!! + finder = MockAwareDocTestFinder(parser=HfDocTestParser()) + # !!!!!!!!!!! HF Specific !!!!!!!!!!! + optionflags = get_optionflags(self) + runner = _get_runner( + verbose=False, + optionflags=optionflags, + checker=_get_checker(), + continue_on_failure=_get_continue_on_failure(self.config), + ) + for test in finder.find(module, module.__name__): + if test.examples: # skip empty doctests and cuda + yield DoctestItem.from_parent(self, name=test.name, runner=runner, dtest=test) + + +def _device_agnostic_dispatch(device: str, dispatch_table: dict[str, Callable], *args, **kwargs): + if device not in dispatch_table: + if not callable(dispatch_table["default"]): + return dispatch_table["default"] + + return dispatch_table["default"](*args, **kwargs) + + fn = dispatch_table[device] + + # Some device agnostic functions return values or None, will return then directly. + if not callable(fn): + return fn + + return fn(*args, **kwargs) + + +if is_torch_available(): + # Mappings from device names to callable functions to support device agnostic + # testing. + BACKEND_MANUAL_SEED = { + "cuda": torch.cuda.manual_seed, + "cpu": torch.manual_seed, + "default": torch.manual_seed, + } + BACKEND_EMPTY_CACHE = { + "cuda": torch.cuda.empty_cache, + "cpu": None, + "default": None, + } + BACKEND_DEVICE_COUNT = { + "cuda": torch.cuda.device_count, + "cpu": lambda: 0, + "default": lambda: 1, + } + BACKEND_RESET_MAX_MEMORY_ALLOCATED = { + "cuda": torch.cuda.reset_max_memory_allocated, + "cpu": None, + "default": None, + } + BACKEND_MAX_MEMORY_ALLOCATED = { + "cuda": torch.cuda.max_memory_allocated, + "cpu": 0, + "default": 0, + } + BACKEND_RESET_PEAK_MEMORY_STATS = { + "cuda": torch.cuda.reset_peak_memory_stats, + "cpu": None, + "default": None, + } + BACKEND_MEMORY_ALLOCATED = { + "cuda": torch.cuda.memory_allocated, + "cpu": 0, + "default": 0, + } + BACKEND_SYNCHRONIZE = { + "cuda": torch.cuda.synchronize, + "cpu": None, + "default": None, + } + BACKEND_TORCH_ACCELERATOR_MODULE = { + "cuda": torch.cuda, + "cpu": None, + "default": None, + } +else: + BACKEND_MANUAL_SEED = {"default": None} + BACKEND_EMPTY_CACHE = {"default": None} + BACKEND_DEVICE_COUNT = {"default": lambda: 0} + BACKEND_RESET_MAX_MEMORY_ALLOCATED = {"default": None} + BACKEND_RESET_PEAK_MEMORY_STATS = {"default": None} + BACKEND_MAX_MEMORY_ALLOCATED = {"default": 0} + BACKEND_MEMORY_ALLOCATED = {"default": 0} + BACKEND_SYNCHRONIZE = {"default": None} + BACKEND_TORCH_ACCELERATOR_MODULE = {"default": None} + + +if is_torch_hpu_available(): + BACKEND_MANUAL_SEED["hpu"] = torch.hpu.manual_seed + BACKEND_DEVICE_COUNT["hpu"] = torch.hpu.device_count + BACKEND_TORCH_ACCELERATOR_MODULE["hpu"] = torch.hpu + +if is_torch_mlu_available(): + BACKEND_EMPTY_CACHE["mlu"] = torch.mlu.empty_cache + BACKEND_MANUAL_SEED["mlu"] = torch.mlu.manual_seed + BACKEND_DEVICE_COUNT["mlu"] = torch.mlu.device_count + BACKEND_TORCH_ACCELERATOR_MODULE["mlu"] = torch.mlu + +if is_torch_npu_available(): + BACKEND_EMPTY_CACHE["npu"] = torch.npu.empty_cache + BACKEND_MANUAL_SEED["npu"] = torch.npu.manual_seed + BACKEND_DEVICE_COUNT["npu"] = torch.npu.device_count + BACKEND_TORCH_ACCELERATOR_MODULE["npu"] = torch.npu + +if is_torch_xpu_available(): + BACKEND_EMPTY_CACHE["xpu"] = torch.xpu.empty_cache + BACKEND_MANUAL_SEED["xpu"] = torch.xpu.manual_seed + BACKEND_DEVICE_COUNT["xpu"] = torch.xpu.device_count + BACKEND_RESET_MAX_MEMORY_ALLOCATED["xpu"] = torch.xpu.reset_peak_memory_stats + BACKEND_RESET_PEAK_MEMORY_STATS["xpu"] = torch.xpu.reset_peak_memory_stats + BACKEND_MAX_MEMORY_ALLOCATED["xpu"] = torch.xpu.max_memory_allocated + BACKEND_MEMORY_ALLOCATED["xpu"] = torch.xpu.memory_allocated + BACKEND_SYNCHRONIZE["xpu"] = torch.xpu.synchronize + BACKEND_TORCH_ACCELERATOR_MODULE["xpu"] = torch.xpu + + +if is_torch_xla_available(): + BACKEND_EMPTY_CACHE["xla"] = torch.cuda.empty_cache + BACKEND_MANUAL_SEED["xla"] = torch.cuda.manual_seed + BACKEND_DEVICE_COUNT["xla"] = torch.cuda.device_count + + +def backend_manual_seed(device: str, seed: int): + return _device_agnostic_dispatch(device, BACKEND_MANUAL_SEED, seed) + + +def backend_empty_cache(device: str): + return _device_agnostic_dispatch(device, BACKEND_EMPTY_CACHE) + + +def backend_device_count(device: str): + return _device_agnostic_dispatch(device, BACKEND_DEVICE_COUNT) + + +def backend_reset_max_memory_allocated(device: str): + return _device_agnostic_dispatch(device, BACKEND_RESET_MAX_MEMORY_ALLOCATED) + + +def backend_reset_peak_memory_stats(device: str): + return _device_agnostic_dispatch(device, BACKEND_RESET_PEAK_MEMORY_STATS) + + +def backend_max_memory_allocated(device: str): + return _device_agnostic_dispatch(device, BACKEND_MAX_MEMORY_ALLOCATED) + + +def backend_memory_allocated(device: str): + return _device_agnostic_dispatch(device, BACKEND_MEMORY_ALLOCATED) + + +def backend_synchronize(device: str): + return _device_agnostic_dispatch(device, BACKEND_SYNCHRONIZE) + + +def backend_torch_accelerator_module(device: str): + return _device_agnostic_dispatch(device, BACKEND_TORCH_ACCELERATOR_MODULE) + + +if is_torch_available(): + # If `TRANSFORMERS_TEST_DEVICE_SPEC` is enabled we need to import extra entries + # into device to function mappings. + if "TRANSFORMERS_TEST_DEVICE_SPEC" in os.environ: + device_spec_path = os.environ["TRANSFORMERS_TEST_DEVICE_SPEC"] + if not Path(device_spec_path).is_file(): + raise ValueError( + f"Specified path to device spec file is not a file or not found. Received '{device_spec_path}" + ) + + # Try to strip extension for later import – also verifies we are importing a + # python file. + device_spec_dir, _ = os.path.split(os.path.realpath(device_spec_path)) + sys.path.append(device_spec_dir) + try: + import_name = device_spec_path[: device_spec_path.index(".py")] + except ValueError as e: + raise ValueError(f"Provided device spec file was not a Python file! Received '{device_spec_path}") from e + + device_spec_module = importlib.import_module(import_name) + + # Imported file must contain `DEVICE_NAME`. If it doesn't, terminate early. + try: + device_name = device_spec_module.DEVICE_NAME + except AttributeError as e: + raise AttributeError("Device spec file did not contain `DEVICE_NAME`") from e + + if "TRANSFORMERS_TEST_DEVICE" in os.environ and torch_device != device_name: + msg = f"Mismatch between environment variable `TRANSFORMERS_TEST_DEVICE` '{torch_device}' and device found in spec '{device_name}'\n" + msg += "Either unset `TRANSFORMERS_TEST_DEVICE` or ensure it matches device spec name." + raise ValueError(msg) + + torch_device = device_name + + def update_mapping_from_spec(device_fn_dict: dict[str, Callable], attribute_name: str): + try: + # Try to import the function directly + spec_fn = getattr(device_spec_module, attribute_name) + device_fn_dict[torch_device] = spec_fn + except AttributeError as e: + # If the function doesn't exist, and there is no default, throw an error + if "default" not in device_fn_dict: + raise AttributeError( + f"`{attribute_name}` not found in '{device_spec_path}' and no default fallback function found." + ) from e + + # Add one entry here for each `BACKEND_*` dictionary. + update_mapping_from_spec(BACKEND_MANUAL_SEED, "MANUAL_SEED_FN") + update_mapping_from_spec(BACKEND_EMPTY_CACHE, "EMPTY_CACHE_FN") + update_mapping_from_spec(BACKEND_DEVICE_COUNT, "DEVICE_COUNT_FN") + + +def compare_pipeline_output_to_hub_spec(output, hub_spec): + missing_keys = [] + unexpected_keys = [] + all_field_names = {field.name for field in fields(hub_spec)} + matching_keys = sorted([key for key in output if key in all_field_names]) + + # Fields with a MISSING default are required and must be in the output + for field in fields(hub_spec): + if field.default is MISSING and field.name not in output: + missing_keys.append(field.name) + + # All output keys must match either a required or optional field in the Hub spec + for output_key in output: + if output_key not in all_field_names: + unexpected_keys.append(output_key) + + if missing_keys or unexpected_keys: + error = ["Pipeline output does not match Hub spec!"] + if matching_keys: + error.append(f"Matching keys: {matching_keys}") + if missing_keys: + error.append(f"Missing required keys in pipeline output: {missing_keys}") + if unexpected_keys: + error.append(f"Keys in pipeline output that are not in Hub spec: {unexpected_keys}") + raise KeyError("\n".join(error)) + + +@require_torch +def cleanup(device: str, gc_collect=False): + if gc_collect: + gc.collect() + backend_empty_cache(device) + torch._dynamo.reset() + + +# Type definition of key used in `Expectations` class. +DeviceProperties = tuple[Optional[str], Optional[int], Optional[int]] +# Helper type. Makes creating instances of `Expectations` smoother. +PackedDeviceProperties = tuple[Optional[str], Union[None, int, tuple[int, int]]] + + +@cache +def get_device_properties() -> DeviceProperties: + """ + Get environment device properties. + """ + if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM: + import torch + + major, minor = torch.cuda.get_device_capability() + if IS_ROCM_SYSTEM: + return ("rocm", major, minor) + else: + return ("cuda", major, minor) + elif IS_XPU_SYSTEM: + import torch + + # To get more info of the architecture meaning and bit allocation, refer to https://github.com/intel/llvm/blob/sycl/sycl/include/sycl/ext/oneapi/experimental/device_architecture.def + arch = torch.xpu.get_device_capability()["architecture"] + gen_mask = 0x000000FF00000000 + gen = (arch & gen_mask) >> 32 + return ("xpu", gen, None) + else: + return (torch_device, None, None) + + +def unpack_device_properties( + properties: Optional[PackedDeviceProperties] = None, +) -> DeviceProperties: + """ + Unpack a `PackedDeviceProperties` tuple into consistently formatted `DeviceProperties` tuple. If properties is None, it is fetched. + """ + if properties is None: + return get_device_properties() + device_type, major_minor = properties + if major_minor is None: + major, minor = None, None + elif isinstance(major_minor, int): + major, minor = major_minor, None + else: + major, minor = major_minor + return device_type, major, minor + + +class Expectations(UserDict[PackedDeviceProperties, Any]): + def get_expectation(self) -> Any: + """ + Find best matching expectation based on environment device properties. We look at device_type, major and minor + versions of the drivers. Expectations are stored as a dictionary with keys of the form + (device_type, (major, minor)). If the major and minor versions are not provided, we use None. + """ + return self.find_expectation(get_device_properties()) + + def unpacked(self) -> list[tuple[DeviceProperties, Any]]: + return [(unpack_device_properties(k), v) for k, v in self.data.items()] + + @staticmethod + def is_default(expectation_key: PackedDeviceProperties) -> bool: + """ + This function returns True if the expectation_key is the Default expectation (None, None). + When an Expectation dict contains a Default value, it is generally because the test existed before Expectations. + When we modify a test to use Expectations for a specific hardware, we don't want to affect the tests on other + hardwares. Thus we set the previous value as the Default expectation with key (None, None) and add a value for + the specific hardware with key (hardware_type, (major, minor)). + """ + return all(p is None for p in expectation_key) + + @staticmethod + def score(properties: DeviceProperties, other: DeviceProperties) -> float: + """ + Returns score indicating how similar two instances of the `Properties` tuple are. + Rules are as follows: + * Matching `type` adds one point, semi-matching `type` adds 0.1 point (e.g. cuda and rocm). + * If types match, matching `major` adds another point, and then matching `minor` adds another. + * The Default expectation (None, None) is worth 0.5 point, which is better than semi-matching. More on this + in the `is_default` function. + """ + device_type, major, minor = properties + other_device_type, other_major, other_minor = other + + score = 0 + # Matching device type, maybe major and minor + if device_type is not None and device_type == other_device_type: + score += 1 + if major is not None and major == other_major: + score += 1 + if minor is not None and minor == other_minor: + score += 1 + # Semi-matching device type, which carries less importance than the default expectation + elif device_type in ["cuda", "rocm"] and other_device_type in ["cuda", "rocm"]: + score = 0.1 + + # Default expectation + if Expectations.is_default(other): + score = 0.5 + + return score + + def find_expectation(self, properties: DeviceProperties = (None, None, None)) -> Any: + """ + Find best matching expectation based on provided device properties. We score each expectation, and to + distinguish between expectations with the same score, we use the major and minor version numbers, prioritizing + most recent versions. + """ + (result_key, result) = max( + self.unpacked(), + key=lambda x: ( + Expectations.score(properties, x[0]), # x[0] is a device properties tuple (device_type, major, minor) + x[0][1] if x[0][1] is not None else -1, # This key is the major version, -1 if major is None + x[0][2] if x[0][2] is not None else -1, # This key is the minor version, -1 if minor is None + ), + ) + + if Expectations.score(properties, result_key) == 0: + raise ValueError(f"No matching expectation found for {properties}") + + return result + + def __repr__(self): + return f"{self.data}" + + +def patch_torch_compile_force_graph(): + """ + Patch `torch.compile` to always use `fullgraph=True`. + + This is useful when some `torch.compile` tests are running with `fullgraph=False` and we want to be able to run + them with `fullgraph=True` in some occasion (without introducing new tests) to make sure there is no graph break. + + After PR #40137, `CompileConfig.fullgraph` is `False` by default, this patch is necessary. + """ + + force_fullgraph = os.environ.get("TORCH_COMPILE_FORCE_FULLGRAPH", "") + force_fullgraph = force_fullgraph.lower() in ("yes", "true", "on", "t", "y", "1") + + if force_fullgraph: + import torch + + orig_method = torch.compile + + def patched(*args, **kwargs): + # In `torch_compile`, all arguments except `model` is keyword only argument. + kwargs["fullgraph"] = True + return orig_method(*args, **kwargs) + + torch.compile = patched + + +def _get_test_info(): + """ + Collect some information about the current test. + + For example, test full name, line number, stack, traceback, etc. + """ + + full_test_name = os.environ.get("PYTEST_CURRENT_TEST", "").split(" ")[0] + test_file, test_class, test_name = full_test_name.split("::") + + # from the most recent frame to the top frame + stack_from_inspect = inspect.stack() + # but visit from the top frame to the most recent frame + + actual_test_file, _actual_test_class = test_file, test_class + test_frame, test_obj, test_method = None, None, None + for frame in reversed(stack_from_inspect): + # if test_file in str(frame).replace(r"\\", "/"): + # check frame's function + if it has `self` as locals; double check if self has the (function) name + # TODO: Question: How about expanded? + if ( + frame.function == test_name + and "self" in frame.frame.f_locals + and hasattr(frame.frame.f_locals["self"], test_name) + ): + # if test_name == frame.frame.f_locals["self"]._testMethodName: + test_frame = frame + # The test instance + test_obj = frame.frame.f_locals["self"] + # TODO: Do we get the (relative?) path or it's just a file name? + # TODO: Does `test_obj` always have `tearDown` object? + actual_test_file = frame.filename + # TODO: check `test_method` will work used at the several places! + test_method = getattr(test_obj, test_name) + break + + if test_frame is not None: + line_number = test_frame.lineno + + # The frame of `patched` being called (the one and the only one calling `_get_test_info`) + # This is used to get the original method being patched in order to get the context. + frame_of_patched_obj = None + + captured_frames = [] + to_capture = False + # From the most outer (i.e. python's `runpy.py`) frame to most inner frame (i.e. the frame of this method) + # Between `the test method being called` and `before entering `patched``. + for frame in reversed(stack_from_inspect): + if ( + frame.function == test_name + and "self" in frame.frame.f_locals + and hasattr(frame.frame.f_locals["self"], test_name) + ): + to_capture = True + # TODO: check simply with the name is not robust. + elif "patched" == frame.frame.f_code.co_name: + frame_of_patched_obj = frame + to_capture = False + break + if to_capture: + captured_frames.append(frame) + + tb_next = None + for frame_info in reversed(captured_frames): + tb = types.TracebackType(tb_next, frame_info.frame, frame_info.frame.f_lasti, frame_info.frame.f_lineno) + tb_next = tb + test_traceback = tb + + origin_method_being_patched = frame_of_patched_obj.frame.f_locals["orig_method"] + + # An iterable of type `traceback.StackSummary` with each element of type `FrameSummary` + stack = traceback.extract_stack() + # The frame which calls `the original method being patched` + caller_frame = None + # From the most inner (i.e. recent) frame to the most outer frame + for frame in reversed(stack): + if origin_method_being_patched.__name__ in frame.line: + caller_frame = frame + + caller_path = os.path.relpath(caller_frame.filename) + caller_lineno = caller_frame.lineno + + test_lineno = line_number + + # Get the code context in the test function/method. + from _pytest._code.source import Source + + with open(actual_test_file) as fp: + s = fp.read() + source = Source(s) + test_code_context = "\n".join(source.getstatement(test_lineno - 1).lines) + + # Get the code context in the caller (to the patched function/method). + with open(caller_path) as fp: + s = fp.read() + source = Source(s) + caller_code_context = "\n".join(source.getstatement(caller_lineno - 1).lines) + + test_info = f"test:\n\n{full_test_name}\n\n{'-' * 80}\n\ntest context: {actual_test_file}:{test_lineno}\n\n{test_code_context}" + test_info = f"{test_info}\n\n{'-' * 80}\n\ncaller context: {caller_path}:{caller_lineno}\n\n{caller_code_context}" + + return ( + full_test_name, + test_file, + test_lineno, + test_obj, + test_method, + test_frame, + test_traceback, + test_code_context, + caller_path, + caller_lineno, + caller_code_context, + test_info, + ) + + +def _get_call_arguments(code_context): + """ + Analyze the positional and keyword arguments in a call expression. + + This will extract the expressions of the positional and kwyword arguments, and associate them to the positions and + the keyword arugment names. + """ + + def get_argument_name(node): + """Extract the name/expression from an AST node""" + if isinstance(node, ast.Name): + return node.id + elif isinstance(node, ast.Attribute): + return ast.unparse(node) + elif isinstance(node, ast.Constant): + return repr(node.value) + else: + return ast.unparse(node) + + indent = len(code_context) - len(code_context.lstrip()) + code_context = code_context.replace(" " * indent, "") + + try: + # Parse the line + tree = ast.parse(code_context, mode="eval") + + assert isinstance(tree.body, ast.Call) + call_node = tree.body + + if call_node: + result = { + "positional_args": [], + "keyword_args": {}, + "starargs": None, # *args + "kwargs": None, # **kwargs + } + + # Extract positional arguments + for arg in call_node.args: + arg_name = get_argument_name(arg) + result["positional_args"].append(arg_name) + + # Extract keyword arguments + for keyword in call_node.keywords: + if keyword.arg is None: + # This is **kwargs + result["kwargs"] = get_argument_name(keyword.value) + else: + # Regular keyword argument + arg_name = get_argument_name(keyword.value) + result["keyword_args"][keyword.arg] = arg_name + + return result + + except (SyntaxError, AttributeError) as e: + print(f"Error parsing: {e}") + + return None + + +def _prepare_debugging_info(test_info, info): + """Combine the information about the test and the call information to a patched function/method within it.""" + + info = f"{test_info}\n\n{info}" + p = os.path.join(os.environ.get("_PATCHED_TESTING_METHODS_OUTPUT_DIR", ""), "captured_info.txt") + # TODO (ydshieh): This is not safe when we use pytest-xdist with more than 1 worker. + with open(p, "a") as fp: + fp.write(f"{info}\n\n{'=' * 120}\n\n") + + return info + + +def _patched_tearDown(self, *args, **kwargs): + """Used to report a test that has failures captured and handled by patched functions/methods (without re-raise). + + The patched functions/methods refer to the `patched` defined in `_patch_with_call_info`, which is applied to + `torch.testing.assert_close` and `unittest.case.TestCase.assertEqual`. + + The objective is to avoid a failure being silence after being processed. + + If there is any failure that is not handled by the patched functions/methods, we add custom error message for them + along with the usual pytest failure report. + """ + + # Check for regular failures before clearing: + # when `_patched_tearDown` is called, the current test fails due to an assertion error given by a method being + # patched by `_patch_with_call_info`. The patched method catches such an error and continue running the remaining + # statements within the test. If the test fails with another error not handled by the patched methods, we don't let + # pytest to fail and report it but the original failure (the first one that was processed) instead. + # We still record those failures not handled by the patched methods, and add custom messages along with the usual + # pytest failure report. + regular_failures_info = [] + if hasattr(self, "_outcome") and self._outcome.errors: + for error_entry in self._outcome.errors: + test_instance, (exc_type, exc_obj, exc_tb) = error_entry + # breakpoint() + regular_failures_info.append( + { + "message": f"{str(exc_obj)}\n\n", + "type": exc_type.__name__, + "file": "test_modeling_vit.py", + "line": 237, # get_deepest_frame_line(exc_tb) # Your helper function + } + ) + + # Clear the regular failure (i.e. that is not from any of our patched assertion methods) from pytest's records. + self._outcome.errors.clear() + + # reset back to the original tearDown method, so `_patched_tearDown` won't be run by the subsequent tests if they + # have only test failures that are not handle by the patched methods (or no test failure at all). + orig_tearDown = _patched_tearDown.orig_tearDown + type(self).tearDown = orig_tearDown + + # Call the original tearDown + orig_tearDown(self, *args, **kwargs) + + # Get the failure + test_method = getattr(self, self._testMethodName) + captured_failures = test_method.__func__.captured_failures[id(test_method)] + + # TODO: How could we show several exceptions in a sinigle test on the terminal? (Maybe not a good idea) + captured_exceptions = captured_failures[0]["exception"] + captured_traceback = captured_failures[0]["traceback"] + # Show the cpatured information on the terminal. + capturued_info = [x["info"] for x in captured_failures] + capturued_info_str = f"\n\n{'=' * 80}\n\n".join(capturued_info) + + # Enhance the exception message if there were suppressed failures + if regular_failures_info: + enhanced_message = f"""{str(captured_exceptions)} + +{"=" * 80} +Handled Failures: ({len(capturued_info)} handled): +{"-" * 80}\n +{capturued_info_str} + +{"=" * 80} +Unhandled Failures: ({len(regular_failures_info)} unhandled): +{"-" * 80}\n +{", ".join(f"{info['type']}: {info['message']}{info['file']}:{info['line']}" for info in regular_failures_info)} + +{"-" * 80} +Note: This failure occurred after other failures analyzed by the patched assertion methods. +To see the full details, temporarily disable assertion patching. +{"=" * 80}""" + + # Create new exception with enhanced message + enhanced_exception = type(captured_exceptions)(enhanced_message) + enhanced_exception.__cause__ = captured_exceptions.__cause__ + enhanced_exception.__context__ = captured_exceptions.__context__ + + # Raise with your existing traceback reconstruction + captured_exceptions = enhanced_exception + + # clean up the recorded status + del test_method.__func__.captured_failures + + raise captured_exceptions.with_traceback(captured_traceback) + + +def _patch_with_call_info(module_or_class, attr_name, _parse_call_info_func, target_args): + """ + Patch a callerable `attr_name` of a module or class `module_or_class`. + + This will allow us to collect the call information, e.g. the argument names and values, also the literal expressions + passed as the arguments. + """ + orig_method = getattr(module_or_class, attr_name) + if not callable(orig_method): + return + + def patched(*args, **kwargs): + # If the target callable is not called within a test, simply call it without modification. + if not os.environ.get("PYTEST_CURRENT_TEST", ""): + return orig_method(*args, **kwargs) + + try: + orig_method(*args, **kwargs) + except AssertionError as e: + captured_exception = e + # captured_traceback = e.__traceback__ + ( + full_test_name, + test_file, + test_lineno, + test_obj, + test_method, + test_frame, + test_traceback, + test_code_context, + caller_path, + caller_lineno, + caller_code_context, + test_info, + ) = _get_test_info() + test_info = f"{test_info}\n\n{'-' * 80}\n\npatched method: {orig_method.__module__}.{orig_method.__name__}" + call_argument_expressions = _get_call_arguments(caller_code_context) + + # This is specific + info = _parse_call_info_func(orig_method, args, kwargs, call_argument_expressions, target_args) + info = _prepare_debugging_info(test_info, info) + + # If the test is running in a CI environment (e.g. not a manual run), let's raise and fail the test, so it + # behaves as usual. + # On Github Actions or CircleCI, this is set automatically. + # When running manually, it's the user to determine if to set it. + # This is to avoid the patched function being called `with self.assertRaises(AssertionError):` and fails + # because of the missing expected `AssertionError`. + # TODO (ydshieh): If there is way to raise only when we are inside such context managers? + # TODO (ydshieh): How not to record the failure if it happens inside `self.assertRaises(AssertionError)`? + if os.getenv("CI") == "true": + raise captured_exception.with_traceback(test_traceback) + + # Save this, so we can raise at the end of the current test + captured_failure = { + "result": "failed", + "exception": captured_exception, + "traceback": test_traceback, + "info": info, + } + + # Record the failure status and its information, so we can raise it later. + # We are modifying the (unbound) function at class level: not its logic but only adding a new extra + # attribute. + if getattr(test_method.__func__, "captured_failures", None) is None: + test_method.__func__.captured_failures = {} + if id(test_method) not in test_method.__func__.captured_failures: + test_method.__func__.captured_failures[id(test_method)] = [] + test_method.__func__.captured_failures[id(test_method)].append(captured_failure) + + # This modifies the `tearDown` which will be called after every tests, but we reset it back inside + # `_patched_tearDown`. + if not hasattr(type(test_obj).tearDown, "orig_tearDown"): + orig_tearDown = type(test_obj).tearDown + _patched_tearDown.orig_tearDown = orig_tearDown + type(test_obj).tearDown = _patched_tearDown + + setattr(module_or_class, attr_name, patched) + + +def _parse_call_info(func, args, kwargs, call_argument_expressions, target_args): + """ + Prepare a string containing the call info to `func`, e.g. argument names/values/expressions. + """ + signature = inspect.signature(func) + signature_names = [param.name for param_name, param in signature.parameters.items()] + + # called as `self.method_name()` or `xxx.method_name()`. + if len(args) == len(call_argument_expressions["positional_args"]) + 1: + # We simply add "self" as the expression despite it might not be the actual argument name. + # (This part is very unlikely what a user would be interest to know) + call_argument_expressions["positional_args"] = ["self"] + call_argument_expressions["positional_args"] + + param_position_mapping = {param_name: idx for idx, param_name in enumerate(signature_names)} + + arg_info = {} + for arg_name in target_args: + if arg_name in kwargs: + arg_value = kwargs[arg_name] + arg_expr = call_argument_expressions["keyword_args"][arg_name] + else: + arg_pos = param_position_mapping[arg_name] + arg_value = args[arg_pos] + arg_expr = call_argument_expressions["positional_args"][arg_pos] + + arg_value_str = _format_py_obj(arg_value) + arg_info[arg_name] = {"arg_expr": arg_expr, "arg_value_str": arg_value_str} + + info = "" + for arg_name in arg_info: + arg_expr, arg_value_str = arg_info[arg_name]["arg_expr"], arg_info[arg_name]["arg_value_str"] + info += f"{'-' * 80}\n\nargument name: `{arg_name}`\nargument expression: `{arg_expr}`\n\nargument value:\n\n{arg_value_str}\n\n" + + # remove the trailing \n\n + info = info[:-2] + + return info + + +def patch_testing_methods_to_collect_info(): + """ + Patch some methods (`torch.testing.assert_close`, `unittest.case.TestCase.assertEqual`, etc). + + This will allow us to collect the call information, e.g. the argument names and values, also the literal expressions + passed as the arguments. + """ + p = os.path.join(os.environ.get("_PATCHED_TESTING_METHODS_OUTPUT_DIR", ""), "captured_info.txt") + Path(p).unlink(missing_ok=True) + + if is_torch_available(): + import torch + + _patch_with_call_info(torch.testing, "assert_close", _parse_call_info, target_args=("actual", "expected")) + + _patch_with_call_info(unittest.case.TestCase, "assertEqual", _parse_call_info, target_args=("first", "second")) + _patch_with_call_info(unittest.case.TestCase, "assertListEqual", _parse_call_info, target_args=("list1", "list2")) + _patch_with_call_info( + unittest.case.TestCase, "assertTupleEqual", _parse_call_info, target_args=("tuple1", "tuple2") + ) + _patch_with_call_info(unittest.case.TestCase, "assertSetEqual", _parse_call_info, target_args=("set1", "set1")) + _patch_with_call_info(unittest.case.TestCase, "assertDictEqual", _parse_call_info, target_args=("d1", "d2")) + _patch_with_call_info(unittest.case.TestCase, "assertIn", _parse_call_info, target_args=("member", "container")) + _patch_with_call_info(unittest.case.TestCase, "assertNotIn", _parse_call_info, target_args=("member", "container")) + _patch_with_call_info(unittest.case.TestCase, "assertLess", _parse_call_info, target_args=("a", "b")) + _patch_with_call_info(unittest.case.TestCase, "assertLessEqual", _parse_call_info, target_args=("a", "b")) + _patch_with_call_info(unittest.case.TestCase, "assertGreater", _parse_call_info, target_args=("a", "b")) + _patch_with_call_info(unittest.case.TestCase, "assertGreaterEqual", _parse_call_info, target_args=("a", "b")) + + +def torchrun(script: str, nproc_per_node: int, is_torchrun: bool = True, env: Optional[dict] = None): + """Run the `script` using `torchrun` command for multi-processing in a subprocess. Captures errors as necessary.""" + with tempfile.NamedTemporaryFile(mode="w+", suffix=".py") as tmp: + tmp.write(script) + tmp.flush() + tmp.seek(0) + if is_torchrun: + cmd = ( + f"torchrun --nproc_per_node {nproc_per_node} --master_port {get_torch_dist_unique_port()} {tmp.name}" + ).split() + else: + cmd = ["python3", tmp.name] + + # Note that the subprocess will be waited for here, and raise an error if not successful + try: + _ = subprocess.run(cmd, capture_output=True, env=env, text=True, check=True) + except subprocess.CalledProcessError as e: + raise Exception(f"The following error was captured: {e.stderr}") + + +def _format_tensor(t, indent_level=0, sci_mode=None): + """Format torch's tensor in a pretty way to be shown 👀 in the test report.""" + + # `torch.testing.assert_close` could accept python int/float numbers. + if not isinstance(t, torch.Tensor): + t = torch.tensor(t) + + # Simply make the processing below simpler (not to hande both case) + is_scalar = False + if t.ndim == 0: + t = torch.tensor([t]) + is_scalar = True + + # For scalar or one-dimensional tensor, keep it as one-line. If there is only one element along any dimension except + # the last one, we also keep it as one-line. + if t.ndim <= 1 or set(t.shape[0:-1]) == {1}: + # Use `detach` to remove `grad_fn=<...>`, and use `to("cpu")` to remove `device='...'` + t = t.detach().to("cpu") + + # We work directly with the string representation instead the tensor itself + t_str = str(t) + + # remove `tensor( ... )` so keep only the content + t_str = t_str.replace("tensor(", "").replace(")", "") + + # Sometimes there are extra spaces between `[` and the first digit of the first value (for alignment). + # For example `[[ 0.06, -0.51], [-0.76, -0.49]]`. It may have multiple consecutive spaces. + # Let's remove such extra spaces. + while "[ " in t_str: + t_str = t_str.replace("[ ", "[") + + # Put everything in a single line. We replace `\n` by a space ` ` so we still keep `,\n` as `, `. + t_str = t_str.replace("\n", " ") + + # Remove repeated spaces (introduced by the previous step) + while " " in t_str: + t_str = t_str.replace(" ", " ") + + # remove leading `[` and `]` for scalar tensor + if is_scalar: + t_str = t_str[1:-1] + + t_str = " " * 4 * indent_level + t_str + + return t_str + + # Otherwise, we separte the representations of every elements along an outer dimension by new lines (after a `,`). + # The representatioin each element is obtained by calling this function recursively with corrent `indent_level`. + else: + t_str = str(t) + + # (For the recursive calls should receive this value) + if sci_mode is None: + sci_mode = "e+" in t_str or "e-" in t_str + + # Use the original content to determine the scientific mode to use. This is required as the representation of + # t[index] (computed below) maybe have different format regarding scientific notation. + torch.set_printoptions(sci_mode=sci_mode) + + t_str = " " * 4 * indent_level + "[\n" + # Keep the ending `,` for all outer dimensions whose representations are not put in one-line, even if there is + # only one element along that dimension. + t_str += ",\n".join(_format_tensor(x, indent_level=indent_level + 1, sci_mode=sci_mode) for x in t) + t_str += ",\n" + " " * 4 * indent_level + "]" + + torch.set_printoptions(sci_mode=None) + + return t_str + + +def _quote_string(s): + """Given a string `s`, return a python literal expression that give `s` when it is used in a python source code. + + For example, if `s` is the string `abc`, the return value is `"abc"`. + + We choice double quotes over single quote despite `str(s)` would give `'abc'` instead of `"abc"`. + """ + has_single_quote = "'" in s + has_double_quote = '"' in s + + if has_single_quote and has_double_quote: + # replace any double quote by the raw string r'\"'. + s = s.replace('"', r"\"") + return f'"{s}"' + elif has_single_quote: + return f'"{s}"' + elif has_double_quote: + return f"'{s}'" + else: + return f'"{s}"' + + +def _format_py_obj(obj, indent=0, mode="", cache=None, prefix=""): + """Format python objects of basic built-in type in a pretty way so we could copy-past them to code editor easily. + + Currently, this support int, float, str, list, tuple, and dict. + + It also works with `torch.Tensor` via calling `format_tesnor`. + """ + + if cache is None: + cache = {} + else: + if (id(obj), indent, mode, prefix) in cache: + return cache[(id(obj), indent, mode, prefix)] + + # special format method for `torch.Tensor` + if str(obj.__class__) == "": + return _format_tensor(obj) + + elif obj.__class__.__name__ == "str": + quoted_string = _quote_string(obj) + # we don't want the newline being interpreted + quoted_string = quoted_string.replace("\n", r"\n") + output = quoted_string + + elif obj.__class__.__name__ in ["int", "float"]: + # for float like `1/3`, we will get `0.3333333333333333` + output = str(obj) + + elif obj.__class__.__name__ in ["list", "tuple", "dict"]: + parenthesis = { + "list": "[]", + "tuple": "()", + "dict": "{}", + } + p1, p2 = parenthesis[obj.__class__.__name__] + + elements_without_indent = [] + if isinstance(obj, dict): + for idx, (k, v) in enumerate(obj.items()): + last_element = idx == len(obj) - 1 + ok = _format_py_obj(k, indent=indent + 1, mode="one-line", cache=cache) + ov = _format_py_obj( + v, + indent=indent + 1, + mode=mode, + cache=cache, + prefix=ok.lstrip() + ": " + "," if not last_element else "", + ) + # Each element could be multiple-line, but the indent of its first line is removed + elements_without_indent.append(f"{ok.lstrip()}: {ov.lstrip()}") + + else: + for idx, x in enumerate(obj): + last_element = idx == len(obj) - 1 + o = _format_py_obj( + x, indent=indent + 1, mode=mode, cache=cache, prefix="," if not last_element else "" + ) + # Each element could be multiple-line, but the indent of its first line is removed + elements_without_indent.append(o.lstrip()) + + groups = [] + buf = [] + for idx, x in enumerate(elements_without_indent): + buf.append(x) + + x_expanded = "\n" in buf[-1] + not_last_element = idx != len(elements_without_indent) - 1 + # if `x` should be separated from subsequent elements + should_finalize_x = x_expanded or len(f"{' ' * (4 * (indent + 1))}") + len( + ", ".join(buf[-1:]) + ) > 120 - int(not_last_element) + + # if `buf[:-1]` (i.e. without `x`) should be combined together (into one line) + should_finalize_buf = x_expanded + + # the recursive call returns single line, so we can use it to determine if we can fit the width limit + if not should_finalize_buf: + buf_not_fit_into_one_line = len(f"{' ' * (4 * (indent + 1))}") + len(", ".join(buf)) > 120 - int( + not_last_element + ) + should_finalize_buf = buf_not_fit_into_one_line + + # any element of iterable type need to be on its own line + if (type(obj[idx]) if type(obj) is not dict else type(list(obj.values())[idx])) in [list, tuple, dict]: + should_finalize_x = True + should_finalize_buf = True + + # any type change --> need to be added after a new line + prev_type = None + current_type = type(obj[idx]) if type(obj) is not dict else type(list(obj.values())[idx]) + if len(buf) > 1: + prev_type = type(obj[idx - 1]) if type(obj) is not dict else type(list(obj.values())[idx - 1]) + type_changed = current_type != prev_type + if type_changed: + should_finalize_buf = True + + # all elements in the buf are string --> don't finalize the buf by width limit + if prev_type is None or (prev_type is str and current_type is str): + should_finalize_buf = False + + # collect as many elements of string type as possible (without width limit). + # These will be examined as a whole (if not fit into the width, each element would be in its own line) + if current_type is str: + should_finalize_x = False + # `len(buf) == 1` or `obj[idx-1]` is a string + if prev_type in [None, str]: + should_finalize_buf = False + + if should_finalize_buf: + orig_buf_len = len(buf) + + if orig_buf_len > 1: + not_fit_into_one_line = None + + # all elements in `obj` that give `buf[:-1]` are string. + if prev_type is str: + # `-1` at the end: because buf[-2] is not the last element + not_fit_into_one_line = len(f"{' ' * (4 * (indent + 1))}") + len(", ".join(buf[:-1])) > 120 - 1 + + if not_fit_into_one_line: + for x in buf[:-1]: + groups.append([x]) + else: + groups.append(buf[:-1]) + + buf = buf[-1:] + + if should_finalize_x: + groups.append(buf) + buf = [] + + # The last buf + if len(buf) > 0: + not_fit_into_one_line = None + if current_type is str: + # no `-1` at the end: because buf[-1] is the last element + not_fit_into_one_line = len(f"{' ' * (4 * (indent + 1))}") + len(", ".join(buf)) > 120 + + if not_fit_into_one_line: + for x in buf: + groups.append([x]) + else: + groups.append(buf) + + output = f"{' ' * 4 * indent}{p1}\n" + element_strings = [f"{' ' * (4 * (indent + 1))}" + ", ".join(buf) for buf in groups] + output += ",\n".join(element_strings) + output += f"\n{' ' * 4 * indent}{p2}" + + # if all elements are in one-line + no_new_line_in_elements = all("\n" not in x for x in element_strings) + # if yes, we can form a one-line representation of `obj` + could_use_one_line = no_new_line_in_elements + + # if mode == "one-line", this function always returns one-line representation, so `no_new_line_in_elements` + # will be `True`. + if could_use_one_line: + one_line_form = ", ".join([x.lstrip() for x in element_strings]) + one_line_form = f"{p1}{one_line_form}{p2}" + + if mode == "one-line": + return output + + # check with the width limit + could_use_one_line = len(f"{' ' * 4 * indent}") + len(prefix) + len(one_line_form) <= 120 + + # extra conditions for returning one-line representation + def use_one_line_repr(obj): + # interable types + if type(obj) in (list, tuple, dict): + # get all types + element_types = [] + if type(obj) is dict: + element_types.extend(type(x) for x in obj.values()) + elif type(obj) in [list, tuple]: + element_types.extend(type(x) for x in obj) + + # At least one element is of iterable type + if any(x in (list, tuple, dict) for x in element_types): + # If `obj` has more than one element and at least one of them is iterable --> no one line repr. + if len(obj) > 1: + return False + + # only one element that is iterable, but not the same type as `obj` --> no one line repr. + if type(obj) is not type(obj[0]): + return False + + # one-line repr. if possible, without width limit + return no_new_line_in_elements + + # all elements are of simple types, but more than one type --> no one line repr. + if len(set(element_types)) > 1: + return False + + # all elements are of the same simple type + if element_types[0] in [int, float]: + # one-line repr. without width limit + return no_new_line_in_elements + elif element_types[0] is str: + if len(obj) == 1: + # one single string element --> one-line repr. without width limit + return no_new_line_in_elements + else: + # multiple string elements --> one-line repr. if fit into width limit + return could_use_one_line + + # simple types (int, flat, string) + return True + + # width condition combined with specific mode conditions + if use_one_line_repr(obj): + output = f"{' ' * 4 * indent}{one_line_form}" + + cache[(id(obj), indent, mode, prefix)] = output + + return output diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/tf_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/tf_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..11d07f8d7edab7401bd4582b57a095db6552475a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/tf_utils.py @@ -0,0 +1,294 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Union + +import numpy as np +import tensorflow as tf + +from .feature_extraction_utils import BatchFeature +from .tokenization_utils_base import BatchEncoding +from .utils import logging + + +logger = logging.get_logger(__name__) + + +def shape_list(tensor: Union[tf.Tensor, np.ndarray]) -> list[int]: + """ + Deal with dynamic shape in tensorflow cleanly. + + Args: + tensor (`tf.Tensor` or `np.ndarray`): The tensor we want the shape of. + + Returns: + `list[int]`: The shape of the tensor as a list. + """ + if isinstance(tensor, np.ndarray): + return list(tensor.shape) + + dynamic = tf.shape(tensor) + + if tensor.shape == tf.TensorShape(None): + return dynamic + + static = tensor.shape.as_list() + + return [dynamic[i] if s is None else s for i, s in enumerate(static)] + + +def stable_softmax(logits: tf.Tensor, axis: Optional[int] = None, name: Optional[str] = None) -> tf.Tensor: + """ + Stable wrapper that returns the same output as `tf.nn.softmax`, but that works reliably with XLA on CPU. It is + meant as a workaround for the [following issue](https://github.com/tensorflow/tensorflow/issues/55682), and will be + removed after it gets fixed. The arguments and outputs are the same as `tf.nn.softmax`, and relies on the fact that + `softmax(x) = softmax(x + c)` (see https://ogunlao.github.io/2020/04/26/you_dont_really_know_softmax.html). + + Args: + logits (`tf.Tensor`): + Must be one of the following types: half, float32, float64. + axis (`int`, *optional*): + The dimension softmax would be performed on. The default is -1 which indicates the last dimension. + name (`str`, *optional*): + A name for the operation. + + Returns: + `tf.Tensor`: + A Tensor. Has the same type and shape as logits. + """ + # TODO: When the issue linked above gets sorted, add a check on TF version here and use the original function if + # it has the fix. After we drop the support for unfixed versions, remove this function. + return tf.nn.softmax(logits=logits + 1e-9, axis=axis, name=name) + + +def functional_layernorm(inputs, weight, bias, epsilon=1e-5, axis=-1): + # This is a very simplified functional layernorm, designed to duplicate + # the functionality of PyTorch nn.functional.layer_norm when this is needed to port + # models in Transformers. + + if weight.shape.rank != 1 or bias.shape.rank != 1 or not isinstance(axis, int): + raise NotImplementedError("Only 1D weight and bias tensors are supported for now, with only a single axis.") + + # Get mean and variance on the axis to be normalized + mean, variance = tf.nn.moments(inputs, axes=[axis], keepdims=True) + + if axis != -1: + # Reshape scale and weight to have the same rank as inputs, but with 1 dimensions + # on every dimension except axis + shape = [1] * inputs.shape.rank + shape[axis] = shape_list(inputs)[axis] + weight = tf.reshape(weight, shape) + bias = tf.reshape(bias, shape) + + # Compute layer normalization using the batch_normalization + # function. + outputs = tf.nn.batch_normalization( + inputs, + mean, + variance, + offset=bias, + scale=weight, + variance_epsilon=epsilon, + ) + return outputs + + +def scaled_dot_product_attention( + query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale: Optional[float] = None +): + """TF equivalent for torch's nn.functional.scaled_dot_product_attention""" + if dropout_p != 0.0: + raise ValueError( + "Dropout is not supported in this implementation - file an issue " + "with Transformers and ping @Rocketknight1 if you need it for a port!" + ) + if is_causal and attn_mask is not None: + raise ValueError("You cannot specify an attn_mask and is_causal at the same time!") + if is_causal: + attn_mask = tf.ones((tf.shape(query)[-2], tf.shape(key)[-2]), dtype=tf.int32) + attn_mask = tf.experimental.numpy.tril(attn_mask, k=0) + if attn_mask is not None and (attn_mask.dtype.is_integer or attn_mask.dtype.is_bool): + # Convert boolean mask to a negative logit bias + attn_mask = tf.where(attn_mask > 0, tf.cast(0.0, query.dtype), tf.cast(-1000.0, query.dtype)) + logits = tf.einsum("...qd, ...kd -> ...qk", query, key) + if scale is None: + scale = tf.cast(tf.shape(key)[-1], logits.dtype) ** -0.5 + logits *= scale # scale by 1/sqrt(key_dim) + if attn_mask is not None: + logits += attn_mask + probs = tf.nn.softmax(logits) + return probs @ value + + +def flatten(input, start_dim=0, end_dim=-1): + # Replicates the behavior of torch.flatten in TF + + # If end_dim or start_dim is negative, count them from the end + if end_dim < 0: + end_dim += input.shape.rank + if start_dim < 0: + start_dim += input.shape.rank + + if start_dim == end_dim: + return input + + in_shape = tf.shape(input) + flattened_dim = tf.math.reduce_prod(in_shape[start_dim : end_dim + 1]) + out_shape = tf.concat([in_shape[:start_dim], [flattened_dim], in_shape[end_dim + 1 :]], axis=0) + return tf.reshape(input, out_shape) + + +def invert_attention_mask(encoder_attention_mask: tf.Tensor) -> tf.Tensor: + """ + Invert an attention mask (e.g., switches 0. and 1.). + + Args: + encoder_attention_mask (`torch.Tensor`): An attention mask. + + Returns: + `tf.Tensor`: The inverted attention mask. + """ + if not isinstance(encoder_attention_mask, tf.Tensor): + encoder_attention_mask = tf.convert_to_tensor(encoder_attention_mask) # Catches stray NumPy inputs + if encoder_attention_mask.shape.rank == 3: + encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] + if encoder_attention_mask.shape.rank == 2: + encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] + # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition + # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow + # /transformer/transformer_layers.py#L270 + # encoder_extended_attention_mask = (encoder_extended_attention_mask == + # encoder_extended_attention_mask.transpose(-1, -2)) + encoder_extended_attention_mask = ( + tf.cast(1, encoder_attention_mask.dtype) - encoder_extended_attention_mask + ) * encoder_extended_attention_mask.dtype.min + + return encoder_extended_attention_mask + + +def check_embeddings_within_bounds(tensor: tf.Tensor, embed_dim: int, tensor_name: str = "input_ids") -> None: + """ + `tf.gather`, on which TF embedding layers are based, won't check positive out of bound indices on GPU, returning + zeros instead. This function adds a check against that dangerous silent behavior. + + Args: + tensor (`tf.Tensor`): The tensor of indices to check. + embed_dim (`int`): The embedding dimension. + tensor_name (`str`, *optional*): The name of the tensor to use in the error message. + """ + tf.debugging.assert_less( + tensor, + tf.cast(embed_dim, dtype=tensor.dtype), + message=( + f"The maximum value of {tensor_name} ({tf.math.reduce_max(tensor)}) must be smaller than the embedding " + f"layer's input dimension ({embed_dim}). The likely cause is some problem at tokenization time." + ), + ) + + +def save_attributes_to_hdf5_group(group, name, data): + """Saves attributes (data) of the specified name into the HDF5 group. + + This method deals with an inherent problem of HDF5 file which is not able to store data larger than + HDF5_OBJECT_HEADER_LIMIT bytes. + + Args: + group: A pointer to a HDF5 group. + name: A name of the attributes to save. + data: Attributes data to store. + + Raises: + RuntimeError: If any single attribute is too large to be saved. + + Copied from Keras to Transformers to avoid versioning issues. + """ + HDF5_OBJECT_HEADER_LIMIT = 64512 + # Check that no item in `data` is larger than `HDF5_OBJECT_HEADER_LIMIT` + # because in that case even chunking the array would not make the saving + # possible. + bad_attributes = [x for x in data if len(x) > HDF5_OBJECT_HEADER_LIMIT] + + # Expecting this to never be true. + if bad_attributes: + raise RuntimeError( + "The following attributes cannot be saved to HDF5 file because " + f"they are larger than {HDF5_OBJECT_HEADER_LIMIT} " + f"bytes: {bad_attributes}" + ) + + data_npy = np.asarray(data) + + num_chunks = 1 + chunked_data = np.array_split(data_npy, num_chunks) + + # This will never loop forever thanks to the test above. + while any(x.nbytes > HDF5_OBJECT_HEADER_LIMIT for x in chunked_data): + num_chunks += 1 + chunked_data = np.array_split(data_npy, num_chunks) + + if num_chunks > 1: + for chunk_id, chunk_data in enumerate(chunked_data): + group.attrs["%s%d" % (name, chunk_id)] = chunk_data + else: + group.attrs[name] = data + + +def load_attributes_from_hdf5_group(group, name): + """Loads attributes of the specified name from the HDF5 group. + + This method deals with an inherent problem of HDF5 file which is not able to store data larger than + HDF5_OBJECT_HEADER_LIMIT bytes. + + Args: + group: A pointer to a HDF5 group. + name: A name of the attributes to load. + + Returns: + data: Attributes data. + + Copied from Keras to Transformers to avoid versioning issues. + """ + if name in group.attrs: + data = [n.decode("utf8") if hasattr(n, "decode") else n for n in group.attrs[name]] + else: + data = [] + chunk_id = 0 + while "%s%d" % (name, chunk_id) in group.attrs: + data.extend( + [n.decode("utf8") if hasattr(n, "decode") else n for n in group.attrs["%s%d" % (name, chunk_id)]] + ) + chunk_id += 1 + return data + + +def expand_1d(data): + """Expands 1-dimensional `Tensor`s into 2-dimensional `Tensor`s. + Copied from Keras to here to avoid versioning issues.""" + + def _expand_single_1d_tensor(t): + if isinstance(t, tf.Tensor) and t.shape.rank == 1: + return tf.expand_dims(t, axis=-1) + return t + + return tf.nest.map_structure(_expand_single_1d_tensor, data) + + +def convert_batch_encoding(*args, **kwargs): + # Convert HF BatchEncoding/BatchFeature objects in the inputs to dicts that Keras understands + if args and isinstance(args[0], (BatchEncoding, BatchFeature)): + args = list(args) + args[0] = dict(args[0]) + elif "x" in kwargs and isinstance(kwargs["x"], (BatchEncoding, BatchFeature)): + kwargs["x"] = dict(kwargs["x"]) + return args, kwargs diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/time_series_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/time_series_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3a5cf4f2f4d8f636e623c07ae3400ca5e17b5891 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/time_series_utils.py @@ -0,0 +1,225 @@ +# Copyright 2023 The HuggingFace Inc. team. +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Time series distributional output classes and utilities. +""" + +from typing import Callable, Optional + +import torch +from torch import nn +from torch.distributions import ( + AffineTransform, + Distribution, + Independent, + NegativeBinomial, + Normal, + StudentT, + TransformedDistribution, +) + + +class AffineTransformed(TransformedDistribution): + def __init__(self, base_distribution: Distribution, loc=None, scale=None, event_dim=0): + self.scale = 1.0 if scale is None else scale + self.loc = 0.0 if loc is None else loc + + super().__init__(base_distribution, [AffineTransform(loc=self.loc, scale=self.scale, event_dim=event_dim)]) + + @property + def mean(self): + """ + Returns the mean of the distribution. + """ + return self.base_dist.mean * self.scale + self.loc + + @property + def variance(self): + """ + Returns the variance of the distribution. + """ + return self.base_dist.variance * self.scale**2 + + @property + def stddev(self): + """ + Returns the standard deviation of the distribution. + """ + return self.variance.sqrt() + + +class ParameterProjection(nn.Module): + def __init__( + self, in_features: int, args_dim: dict[str, int], domain_map: Callable[..., tuple[torch.Tensor]], **kwargs + ) -> None: + super().__init__(**kwargs) + self.args_dim = args_dim + self.proj = nn.ModuleList([nn.Linear(in_features, dim) for dim in args_dim.values()]) + self.domain_map = domain_map + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor]: + params_unbounded = [proj(x) for proj in self.proj] + + return self.domain_map(*params_unbounded) + + +class LambdaLayer(nn.Module): + def __init__(self, function): + super().__init__() + self.function = function + + def forward(self, x, *args): + return self.function(x, *args) + + +class DistributionOutput: + distribution_class: type + in_features: int + args_dim: dict[str, int] + + def __init__(self, dim: int = 1) -> None: + self.dim = dim + self.args_dim = {k: dim * self.args_dim[k] for k in self.args_dim} + + def _base_distribution(self, distr_args): + if self.dim == 1: + return self.distribution_class(*distr_args) + else: + return Independent(self.distribution_class(*distr_args), 1) + + def distribution( + self, + distr_args, + loc: Optional[torch.Tensor] = None, + scale: Optional[torch.Tensor] = None, + ) -> Distribution: + distr = self._base_distribution(distr_args) + if loc is None and scale is None: + return distr + else: + return AffineTransformed(distr, loc=loc, scale=scale, event_dim=self.event_dim) + + @property + def event_shape(self) -> tuple: + r""" + Shape of each individual event contemplated by the distributions that this object constructs. + """ + return () if self.dim == 1 else (self.dim,) + + @property + def event_dim(self) -> int: + r""" + Number of event dimensions, i.e., length of the `event_shape` tuple, of the distributions that this object + constructs. + """ + return len(self.event_shape) + + @property + def value_in_support(self) -> float: + r""" + A float that will have a valid numeric value when computing the log-loss of the corresponding distribution. By + default 0.0. This value will be used when padding data series. + """ + return 0.0 + + def get_parameter_projection(self, in_features: int) -> nn.Module: + r""" + Return the parameter projection layer that maps the input to the appropriate parameters of the distribution. + """ + return ParameterProjection( + in_features=in_features, + args_dim=self.args_dim, + domain_map=LambdaLayer(self.domain_map), + ) + + def domain_map(self, *args: torch.Tensor): + r""" + Converts arguments to the right shape and domain. The domain depends on the type of distribution, while the + correct shape is obtained by reshaping the trailing axis in such a way that the returned tensors define a + distribution of the right event_shape. + """ + raise NotImplementedError() + + @staticmethod + def squareplus(x: torch.Tensor) -> torch.Tensor: + r""" + Helper to map inputs to the positive orthant by applying the square-plus operation. Reference: + https://twitter.com/jon_barron/status/1387167648669048833 + """ + return (x + torch.sqrt(torch.square(x) + 4.0)) / 2.0 + + +class StudentTOutput(DistributionOutput): + """ + Student-T distribution output class. + """ + + args_dim: dict[str, int] = {"df": 1, "loc": 1, "scale": 1} + distribution_class: type = StudentT + + @classmethod + def domain_map(cls, df: torch.Tensor, loc: torch.Tensor, scale: torch.Tensor): + scale = cls.squareplus(scale).clamp_min(torch.finfo(scale.dtype).eps) + df = 2.0 + cls.squareplus(df) + return df.squeeze(-1), loc.squeeze(-1), scale.squeeze(-1) + + +class NormalOutput(DistributionOutput): + """ + Normal distribution output class. + """ + + args_dim: dict[str, int] = {"loc": 1, "scale": 1} + distribution_class: type = Normal + + @classmethod + def domain_map(cls, loc: torch.Tensor, scale: torch.Tensor): + scale = cls.squareplus(scale).clamp_min(torch.finfo(scale.dtype).eps) + return loc.squeeze(-1), scale.squeeze(-1) + + +class NegativeBinomialOutput(DistributionOutput): + """ + Negative Binomial distribution output class. + """ + + args_dim: dict[str, int] = {"total_count": 1, "logits": 1} + distribution_class: type = NegativeBinomial + + @classmethod + def domain_map(cls, total_count: torch.Tensor, logits: torch.Tensor): + total_count = cls.squareplus(total_count) + return total_count.squeeze(-1), logits.squeeze(-1) + + def _base_distribution(self, distr_args) -> Distribution: + total_count, logits = distr_args + if self.dim == 1: + return self.distribution_class(total_count=total_count, logits=logits) + else: + return Independent(self.distribution_class(total_count=total_count, logits=logits), 1) + + # Overwrites the parent class method. We cannot scale using the affine + # transformation since negative binomial should return integers. Instead + # we scale the parameters. + def distribution( + self, distr_args, loc: Optional[torch.Tensor] = None, scale: Optional[torch.Tensor] = None + ) -> Distribution: + total_count, logits = distr_args + + if scale is not None: + # See scaling property of Gamma. + logits += scale.log() + + return self._base_distribution((total_count, logits)) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/tokenization_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/tokenization_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b89e570931526a24555b4ff86ee1797f906f41fc --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/tokenization_utils.py @@ -0,0 +1,1135 @@ +# Copyright 2020 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tokenization classes for python tokenizers. For fast tokenizers (provided by HuggingFace's tokenizers library) see +tokenization_utils_fast.py +""" + +import bisect +import itertools +import re +import unicodedata +from collections import OrderedDict +from typing import Any, Optional, Union, overload + +from .tokenization_utils_base import ( + ENCODE_KWARGS_DOCSTRING, + ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING, + INIT_TOKENIZER_DOCSTRING, + AddedToken, + BatchEncoding, + EncodedInput, + EncodedInputPair, + PreTokenizedInput, + PreTokenizedInputPair, + PreTrainedTokenizerBase, + TextInput, + TextInputPair, + TruncationStrategy, +) +from .utils import PaddingStrategy, TensorType, add_end_docstrings, logging + + +logger = logging.get_logger(__name__) + +# Slow tokenizers are saved in a vocabulary plus three separated files +SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json" +ADDED_TOKENS_FILE = "added_tokens.json" +TOKENIZER_CONFIG_FILE = "tokenizer_config.json" + + +class Trie: + """ + Trie in Python. Creates a Trie out of a list of words. The trie is used to split on `added_tokens` in one pass + Loose reference https://en.wikipedia.org/wiki/Trie + """ + + def __init__(self, *args): + self.data = {} + self._tokens = set() + self._termination_char = "" + self.update(*args) + + def update(self, *args): + """ + Updates the Trie with new tokens provided as arguments. + + Args: + *args: Variable number of words to be added to the Trie. + """ + for token in tuple(*args): + self.add(token) + + def add(self, word: str): + """ + Passes over every char (utf-8 char) on word and recursively adds it to the internal `data` trie representation. + The special key `""` in `self._termination_char` is used to represent termination. + + This function is idempotent, adding twice the same word will leave the trie unchanged + + Example: + + ```python + >>> trie = Trie() + >>> trie.add("Hello 友達") + >>> trie.data + {"H": {"e": {"l": {"l": {"o": {" ": {"友": {"達": {"": 1}}}}}}}}} + + >>> trie.add("Hello") + >>> trie.data + {"H": {"e": {"l": {"l": {"o": {"": 1, " ": {"友": {"達": {"": 1}}}}}}}}} + ``` + """ + if not word: + # Prevent empty string + return + + self._tokens.add(word) + ref = self.data + for char in word: + ref[char] = ref.setdefault(char, {}) + ref = ref[char] + ref[self._termination_char] = 1 + + def split(self, text: str) -> list[str]: + """ + Will look for the words added to the trie within `text`. Output is the original string split along the + boundaries of the words found. + + This trie will match the longest possible word first ! + + Example: + + ```python + >>> trie = Trie() + >>> trie.split("[CLS] This is a extra_id_100") + ["[CLS] This is a extra_id_100"] + + >>> trie.add("[CLS]") + >>> trie.add("extra_id_1") + >>> trie.add("extra_id_100") + >>> trie.split("[CLS] This is a extra_id_100") + ["[CLS]", " This is a ", "extra_id_100"] + ``` + """ + # indexes are counted left of the chars index. + # "hello", index 0, is left of h, index 1 is between h and e. + # index 5 is right of the "o". + + # States are going to capture every possible start (indexes as above) + # as keys, and have as values, a pointer to the position in the trie + # where we're at. This is a partial match for now. + # This enables to keep track of multiple matches while we're iterating + # the string + # If the trie contains, "blowing", and "lower" and we encounter the + # string "blower", we need to split into ["b", "lower"]. + # This is where we need to keep track of multiple possible starts. + states = OrderedDict() + + # This will contain every indices where we need + # to cut. + # We force to cut at offset 0 and len(text) (added later) + offsets = [0] + + # This is used by the lookahead which needs to skip over + # some text where the full match exceeded the place in the initial + # for loop + skip = 0 + # Main loop, Giving this algorithm O(n) complexity + for current, current_char in enumerate(text): + if skip and current < skip: + # Prevents the lookahead for matching twice + # like extra_id_100 and id_100 + continue + + # This will track every state + # that stop matching, we need to stop tracking them. + # If we look at "lowball", we're going to match "l" (add it to states), "o", "w", then + # fail on "b", we need to remove 0 from the valid states. + to_remove = set() + # Whenever we found a match, we need to drop everything + # this is a greedy algorithm, it will match on the first found token + reset = False + + # In this case, we already have partial matches (But unfinished) + for start, trie_pointer in states.items(): + if "" in trie_pointer: + # This is a final match, we need to reset and + # store the results in `offsets`. + + # Lookahead to match longest first + # Important in case of extra_id_1 vs extra_id_100 + # Here we are also actively looking for other earlier partial + # matches + # "[CLS]", "L", we need to match CLS even if L is special + for lookstart, looktrie_pointer in states.items(): + if lookstart > start: + # This partial match is later, we can stop looking + break + elif lookstart < start: + # This partial match is earlier, the trie pointer + # was already updated, so index is + 1 + lookahead_index = current + 1 + end = current + 1 + else: + # Here lookstart == start and + # looktrie_pointer == trie_pointer + # It wasn't updated yet so indices are current ones + lookahead_index = current + end = current + next_char = text[lookahead_index] if lookahead_index < len(text) else None + if "" in looktrie_pointer: + start = lookstart + end = lookahead_index + skip = lookahead_index + + while next_char in looktrie_pointer: + looktrie_pointer = looktrie_pointer[next_char] + lookahead_index += 1 + if "" in looktrie_pointer: + start = lookstart + end = lookahead_index + skip = lookahead_index + + if lookahead_index == len(text): + # End of string + break + next_char = text[lookahead_index] + # End lookahead + + # Storing and resetting + offsets.append(start) + offsets.append(end) + reset = True + break + elif current_char in trie_pointer: + # The current character being looked at has a match within the trie + # update the pointer (it will be stored back into states later). + trie_pointer = trie_pointer[current_char] + + # Storing back the new pointer into the states. + # Partial matches got longer by one. + states[start] = trie_pointer + else: + # The new character has not match in the trie, we need + # to stop keeping track of this partial match. + # We can't do it directly within the loop because of how + # python iteration works + to_remove.add(start) + + # Either clearing the full start (we found a real match) + # Or clearing only the partial matches that didn't work. + if reset: + states = {} + else: + for start in to_remove: + del states[start] + + # If this character is a starting character within the trie + # start keeping track of this partial match. + if current >= skip and current_char in self.data: + states[current] = self.data[current_char] + + # We have a cut at the end with states. + for start, trie_pointer in states.items(): + if "" in trie_pointer: + # This is a final match, we need to reset and + # store the results in `offsets`. + end = len(text) + offsets.append(start) + offsets.append(end) + # Longest cut is always the one with lower start so the first + # item so we need to break. + break + + return self.cut_text(text, offsets) + + def cut_text(self, text, offsets): + # We have all the offsets now, we just need to do the actual splitting. + # We need to eventually add the first part of the string and the eventual + # last part. + offsets.append(len(text)) + tokens = [] + start = 0 + for end in offsets: + if start > end: + logger.error( + "There was a bug in Trie algorithm in tokenization. Attempting to recover. Please report it" + " anyway." + ) + continue + elif start == end: + # This might happen if there's a match at index 0 + # we're also preventing zero-width cuts in case of two + # consecutive matches + continue + tokens.append(text[start:end]) + start = end + + return tokens + + +class ExtensionsTrie(Trie): + def __init__(self, *args): + super().__init__(*args) + + def extensions(self, prefix: str): + """ + Generates all extensions of a given prefix token in the Trie. + + Example: + + ```python + >>> trie = Trie() + >>> trie.add("apple") + >>> trie.add("app") + >>> trie.add("application") + >>> trie.extensions("app") + ['app', 'apple', 'application'] + ``` + """ + prefix_node = self._get_node(prefix) + ret = self._collect_tokens(prefix_node) + return [prefix + token for token in ret] + + def _get_node(self, token: str) -> dict: + """ + Retrieves the node corresponding to the given token in the Trie. + + Args: + token (str): The token for which the corresponding node needs to be retrieved. + + Returns: + dict: The node in the Trie corresponding to the given token. + """ + node = self.data + for char in token: + if char not in node: + break + + node = node[char] + return node + + def _collect_tokens(self, node: dict) -> list: + """ + Generates all tokens in the Trie starting from a given node. + + Args: + node (dict): The node in the Trie from which tokens need to be generated. + + Returns: + list: List of tokens generated from the given node. + """ + tokens = [self._termination_char] if self._termination_char in node else [] + for token, subtrie_head in node.items(): + if token != self._termination_char: + subtokens = self._collect_tokens(subtrie_head) + tokens.extend([token + subtoken for subtoken in subtokens]) + return tokens + + +def _is_whitespace(char): + """Checks whether `char` is a whitespace character.""" + # \t, \n, and \r are technically control characters but we treat them + # as whitespace since they are generally considered as such. + if char == " " or char == "\t" or char == "\n" or char == "\r": + return True + cat = unicodedata.category(char) + if cat == "Zs": + return True + return False + + +def _is_control(char): + """Checks whether `char` is a control character.""" + # These are technically control characters but we count them as whitespace + # characters. + if char == "\t" or char == "\n" or char == "\r": + return False + cat = unicodedata.category(char) + if cat.startswith("C"): + return True + return False + + +def _is_punctuation(char): + """Checks whether `char` is a punctuation character.""" + cp = ord(char) + # We treat all non-letter/number ASCII as punctuation. + # Characters such as "^", "$", and "`" are not in the Unicode + # Punctuation class but we treat them as punctuation anyways, for + # consistency. + if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126): + return True + cat = unicodedata.category(char) + if cat.startswith("P"): + return True + return False + + +def _is_end_of_word(text): + """Checks whether the last character in text is one of a punctuation, control or whitespace character.""" + last_char = text[-1] + return bool(_is_control(last_char) | _is_punctuation(last_char) | _is_whitespace(last_char)) + + +def _is_start_of_word(text): + """Checks whether the first character in text is one of a punctuation, control or whitespace character.""" + first_char = text[0] + return bool(_is_control(first_char) | _is_punctuation(first_char) | _is_whitespace(first_char)) + + +def _insert_one_token_to_ordered_list(token_list: list[str], new_token: str): + """ + Inserts one token to an ordered list if it does not already exist. Note: token_list must be sorted. + """ + insertion_idx = bisect.bisect_left(token_list, new_token) + # Checks if new_token is already in the ordered token_list + if insertion_idx < len(token_list) and token_list[insertion_idx] == new_token: + # new_token is in token_list, don't add + return + else: + token_list.insert(insertion_idx, new_token) + + +@add_end_docstrings(INIT_TOKENIZER_DOCSTRING) +class PreTrainedTokenizer(PreTrainedTokenizerBase): + """ + Base class for all slow tokenizers. + + Inherits from [`~tokenization_utils_base.PreTrainedTokenizerBase`]. + + Handle all the shared methods for tokenization and special tokens as well as methods downloading/caching/loading + pretrained tokenizers as well as adding tokens to the vocabulary. + + This class also contain the added tokens in a unified way on top of all tokenizers so we don't have to handle the + specific vocabulary augmentation methods of the various underlying dictionary structures (BPE, sentencepiece...). + """ + + def __init__(self, **kwargs): + # 1. Init the parent class + + self.tokens_trie = Trie() + + # 2. init `_added_tokens_decoder` if child class did not + if not hasattr(self, "_added_tokens_decoder"): + self._added_tokens_decoder: dict[int, AddedToken] = {} + + # 3. if a `added_tokens_decoder` is passed, we are loading from a saved tokenizer, we overwrite + self._added_tokens_decoder.update(kwargs.pop("added_tokens_decoder", {})) + self._added_tokens_encoder: dict[str, int] = {k.content: v for v, k in self._added_tokens_decoder.items()} + + # 4 init the parent class + super().__init__(**kwargs) + + # 4. If some of the special tokens are not part of the vocab, we add them, at the end. + # the order of addition is the same as self.SPECIAL_TOKENS_ATTRIBUTES following `tokenizers` + self._add_tokens( + [token for token in self.all_special_tokens_extended if token not in self._added_tokens_encoder], + special_tokens=True, + ) + + self._decode_use_source_tokenizer = False + + @property + def is_fast(self) -> bool: + return False + + @property + def vocab_size(self) -> int: + """ + `int`: Size of the base vocabulary (without the added tokens). + """ + raise NotImplementedError + + @property + def added_tokens_encoder(self) -> dict[str, int]: + """ + Returns the sorted mapping from string to index. The added tokens encoder is cached for performance + optimisation in `self._added_tokens_encoder` for the slow tokenizers. + """ + return {k.content: v for v, k in sorted(self._added_tokens_decoder.items(), key=lambda item: item[0])} + + @property + def added_tokens_decoder(self) -> dict[int, AddedToken]: + """ + Returns the added tokens in the vocabulary as a dictionary of index to AddedToken. + + Returns: + `dict[str, int]`: The added tokens. + """ + return dict(sorted(self._added_tokens_decoder.items(), key=lambda item: item[0])) + + @added_tokens_decoder.setter + def added_tokens_decoder(self, value: dict[int, Union[AddedToken, str]]) -> dict[int, AddedToken]: + # Always raise an error if string because users should define the behavior + for index, token in value.items(): + if not isinstance(token, (str, AddedToken)) or not isinstance(index, int): + raise TypeError( + f"The provided `added_tokens_decoder` has an element of type {index.__class__, token.__class__}, should be a dict of {int, Union[AddedToken, str]}" + ) + + self._added_tokens_decoder[index] = AddedToken(token) if isinstance(token, str) else token + self._added_tokens_encoder[str(token)] = index + self._update_total_vocab_size() + + def get_added_vocab(self) -> dict[str, int]: + """ + Returns the added tokens in the vocabulary as a dictionary of token to index. Results might be different from + the fast call because for now we always add the tokens even if they are already in the vocabulary. This is + something we should change. + + Returns: + `dict[str, int]`: The added tokens. + """ + return self._added_tokens_encoder + + def __len__(self): + """ + Size of the full vocabulary with the added tokens. + """ + return self.total_vocab_size + + def _update_total_vocab_size(self): + """ + Update the size of the full vocabulary with the added tokens. Counts the `keys` and not the `values` because + otherwise if there is a hole in the vocab, we will add tokenizers at a wrong index. This operation is slow and + is only updated when adding tokens. + """ + self.total_vocab_size = len(self.get_vocab()) + + def _add_tokens(self, new_tokens: Union[list[str], list[AddedToken]], special_tokens: bool = False) -> int: + """ + Add a list of new tokens to the tokenizer class. If the new tokens are not in the vocabulary, they are added to + it with indices starting from length of the current vocabulary. Special tokens are sometimes already in the + vocab which is why they have to be handled specifically. + + Args: + new_tokens (`list[str]`or `list[tokenizers.AddedToken]`): + Token(s) to add in vocabulary. A token is counted as added if it's not already in the vocabulary + (tested by checking if the tokenizer assign the index of the `unk_token` to them). If a token is part + of the vocabulary then we simply mark this token as an `AddedToken` which allows to control the + stripping and normalization of this token. This is NOT possible in `tokenizers`. + special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the tokens should be added as special tokens. + + Returns: + `int`: The number of tokens actually added to the vocabulary. + + Examples: + + ```python + # Let's see how to increase the vocabulary of Bert model and tokenizer + tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased") + model = BertModel.from_pretrained("google-bert/bert-base-uncased") + + num_added_toks = tokenizer.add_tokens(["new_tok1", "my_new-tok2"]) + print("We have added", num_added_toks, "tokens") + # Note: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e. the length of the tokenizer. + model.resize_token_embeddings(len(tokenizer)) + ```""" + added_tokens = 0 + if new_tokens is None: + return added_tokens + # TODO this is fairly slow to improve! + current_vocab = self.get_vocab().copy() + new_idx = len(current_vocab) # only call this once, len gives the last index + 1 + for token in new_tokens: + if not isinstance(token, (str, AddedToken)): + raise TypeError(f"Token {token} is not a string but a {type(token)}.") + if str(token) == "": + continue + if isinstance(token, str): + if token in self._added_tokens_encoder: + continue + else: + # very important for fast and slow equivalence! + is_special = token in self.all_special_tokens or special_tokens + token = AddedToken( + token, rstrip=False, lstrip=False, normalized=not is_special, special=is_special + ) + elif special_tokens: + # doing token.special=True changes the normalization! will fix in rust + # this is important and the only reason why the AddedTokens in each class are normalized by default + token.__setstate__({"special": True, "normalized": token.normalized}) + if token in self._added_tokens_decoder: + continue + if not token.special and token.normalized and getattr(self, "do_lower_case", False): + # Normalize if requested + token.content = token.content.lower() + if token.content not in current_vocab: + token_index = new_idx + added_tokens + current_vocab[token.content] = token_index + added_tokens += 1 + else: + token_index = current_vocab[token.content] + + if token.special and str(token) not in self.all_special_tokens: + self._special_tokens_map["additional_special_tokens"].append(token) + # the setter automatically updates the reverse map + self._added_tokens_decoder[token_index] = token + self._added_tokens_encoder[token.content] = token_index + if self.verbose: + logger.info(f"Adding {token} to the vocabulary") + + self._update_trie() + self._update_total_vocab_size() + return added_tokens + + def _update_trie(self, unique_no_split_tokens: Optional[list[str]] = None): + for token in self._added_tokens_decoder.values(): + if token.content not in self.tokens_trie._tokens: + self.tokens_trie.add(token.content) + for token in unique_no_split_tokens or []: + if token not in self.tokens_trie._tokens: + self.tokens_trie.add(token) + + def num_special_tokens_to_add(self, pair: bool = False) -> int: + """ + Returns the number of added tokens when encoding a sequence with special tokens. + + + + This encodes a dummy input and checks the number of added tokens, and is therefore not efficient. Do not put + this inside your training loop. + + + + Args: + pair (`bool`, *optional*, defaults to `False`): + Whether the number of added tokens should be computed in the case of a sequence pair or a single + sequence. + + Returns: + `int`: Number of special tokens added to sequences. + """ + token_ids_0 = [] + token_ids_1 = [] + return len(self.build_inputs_with_special_tokens(token_ids_0, token_ids_1 if pair else None)) + + def tokenize(self, text: TextInput, **kwargs) -> list[str]: + """ + Converts a string into a sequence of tokens, using the tokenizer. + + Split in words for word-based vocabulary or sub-words for sub-word-based vocabularies + (BPE/SentencePieces/WordPieces). Takes care of added tokens. + + Args: + text (`str`): + The sequence to be encoded. + **kwargs (additional keyword arguments): + Passed along to the model-specific `prepare_for_tokenization` preprocessing method. + + Returns: + `list[str]`: The list of tokens. + """ + split_special_tokens = kwargs.pop("split_special_tokens", self.split_special_tokens) + + text, kwargs = self.prepare_for_tokenization(text, **kwargs) + + if kwargs: + logger.warning(f"Keyword arguments {kwargs} not recognized.") + + if hasattr(self, "do_lower_case") and self.do_lower_case: + # convert non-special tokens to lowercase. Might be super slow as well? + escaped_special_toks = [re.escape(s_tok) for s_tok in (self.all_special_tokens)] + escaped_special_toks += [ + re.escape(s_tok.content) + for s_tok in (self._added_tokens_decoder.values()) + if not s_tok.special and s_tok.normalized + ] + pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)" + text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text) + + if split_special_tokens: + no_split_token = [] + tokens = [text] + else: + no_split_token = self._added_tokens_encoder.keys() # don't split on any of the added tokens + # "This is something else" + tokens = self.tokens_trie.split(text) + + # ["This is something", "", " else"] + for i, token in enumerate(tokens): + if token in no_split_token: + tok_extended = self._added_tokens_decoder.get(self._added_tokens_encoder[token], None) + left = tokens[i - 1] if i > 0 else None + right = tokens[i + 1] if i < len(tokens) - 1 else None + if isinstance(tok_extended, AddedToken): + if tok_extended.rstrip and right: + # A bit counter-intuitive but we strip the left of the string + # since tok_extended.rstrip means the special token is eating all white spaces on its right + tokens[i + 1] = right.lstrip() + # Strip white spaces on the left + if tok_extended.lstrip and left: + tokens[i - 1] = left.rstrip() # Opposite here + if tok_extended.single_word and left and left[-1] != " ": + tokens[i - 1] += token + tokens[i] = "" + elif tok_extended.single_word and right and right[0] != " ": + tokens[i + 1] = token + tokens[i + 1] + tokens[i] = "" + else: + raise ValueError( + f"{tok_extended} cannot be tokenized because it was not properly added" + f" to the tokenizer. This means that it is not an `AddedToken` but a {type(tok_extended)}" + ) + # ["This is something", "", "else"] + tokenized_text = [] + for token in tokens: + # Need to skip eventual empty (fully stripped) tokens + if not token: + continue + if token in no_split_token: + tokenized_text.append(token) + else: + tokenized_text.extend(self._tokenize(token)) + # ["This", " is", " something", "", "else"] + return tokenized_text + + def _tokenize(self, text, **kwargs): + """ + Converts a string into a sequence of tokens (string), using the tokenizer. Split in words for word-based + vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces). + + Do NOT take care of added tokens. + """ + raise NotImplementedError + + def convert_tokens_to_ids(self, tokens: Union[str, list[str]]) -> Union[int, list[int]]: + """ + Converts a token string (or a sequence of tokens) in a single integer id (or a sequence of ids), using the + vocabulary. + + Args: + tokens (`str` or `list[str]`): One or several token(s) to convert to token id(s). + + Returns: + `int` or `list[int]`: The token id or list of token ids. + """ + if tokens is None: + return None + + if isinstance(tokens, str): + return self._convert_token_to_id_with_added_voc(tokens) + + ids = [] + for token in tokens: + ids.append(self._convert_token_to_id_with_added_voc(token)) + return ids + + def _convert_token_to_id_with_added_voc(self, token): + if token is None: + return None + + if token in self._added_tokens_encoder: + return self._added_tokens_encoder[token] + return self._convert_token_to_id(token) + + def _convert_token_to_id(self, token): + raise NotImplementedError + + def _encode_plus( + self, + text: Union[TextInput, PreTokenizedInput, EncodedInput], + text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None, + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + is_split_into_words: bool = False, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + def get_input_ids(text): + if isinstance(text, str): + tokens = self.tokenize(text, **kwargs) + return self.convert_tokens_to_ids(tokens) + elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str): + if is_split_into_words: + tokens = list( + itertools.chain(*(self.tokenize(t, is_split_into_words=True, **kwargs) for t in text)) + ) + return self.convert_tokens_to_ids(tokens) + else: + return self.convert_tokens_to_ids(text) + elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int): + return text + else: + if is_split_into_words: + raise ValueError( + f"Input {text} is not valid. Should be a string or a list/tuple of strings when" + " `is_split_into_words=True`." + ) + else: + raise ValueError( + f"Input {text} is not valid. Should be a string, a list/tuple of strings or a list/tuple of" + " integers." + ) + + if return_offsets_mapping: + raise NotImplementedError( + "return_offset_mapping is not available when using Python tokenizers. " + "To use this feature, change your tokenizer to one deriving from " + "transformers.PreTrainedTokenizerFast. " + "More information on available tokenizers at " + "https://github.com/huggingface/transformers/pull/2674" + ) + + first_ids = get_input_ids(text) + second_ids = get_input_ids(text_pair) if text_pair is not None else None + + return self.prepare_for_model( + first_ids, + pair_ids=second_ids, + add_special_tokens=add_special_tokens, + padding=padding_strategy.value, + truncation=truncation_strategy.value, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_tensors=return_tensors, + prepend_batch_axis=True, + return_attention_mask=return_attention_mask, + return_token_type_ids=return_token_type_ids, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_length=return_length, + verbose=verbose, + ) + + def _batch_encode_plus( + self, + batch_text_or_text_pairs: Union[ + list[TextInput], + list[TextInputPair], + list[PreTokenizedInput], + list[PreTokenizedInputPair], + list[EncodedInput], + list[EncodedInputPair], + ], + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + is_split_into_words: bool = False, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + split_special_tokens: bool = False, + **kwargs, + ) -> BatchEncoding: + def get_input_ids(text): + if isinstance(text, str): + tokens = self.tokenize(text, **kwargs) + return self.convert_tokens_to_ids(tokens) + elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str): + if is_split_into_words: + tokens = list( + itertools.chain(*(self.tokenize(t, is_split_into_words=True, **kwargs) for t in text)) + ) + return self.convert_tokens_to_ids(tokens) + else: + return self.convert_tokens_to_ids(text) + elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int): + return text + else: + raise ValueError( + "Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers." + ) + + if return_offsets_mapping: + raise NotImplementedError( + "return_offset_mapping is not available when using Python tokenizers. " + "To use this feature, change your tokenizer to one deriving from " + "transformers.PreTrainedTokenizerFast." + ) + + input_ids = [] + for ids_or_pair_ids in batch_text_or_text_pairs: + if ( + not isinstance(ids_or_pair_ids, (list, tuple)) + or is_split_into_words + and not isinstance(ids_or_pair_ids[0], (list, tuple)) + ): + ids, pair_ids = ids_or_pair_ids, None + else: + ids, pair_ids = ids_or_pair_ids + + first_ids = get_input_ids(ids) + second_ids = get_input_ids(pair_ids) if pair_ids is not None else None + input_ids.append((first_ids, second_ids)) + + batch_outputs = self._batch_prepare_for_model( + input_ids, + add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_attention_mask=return_attention_mask, + return_token_type_ids=return_token_type_ids, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_length=return_length, + return_tensors=return_tensors, + verbose=verbose, + split_special_tokens=split_special_tokens, + ) + + return BatchEncoding(batch_outputs) + + @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + def _batch_prepare_for_model( + self, + batch_ids_pairs: list[Union[PreTokenizedInputPair, tuple[list[int], None]]], + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[str] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_length: bool = False, + verbose: bool = True, + split_special_tokens: bool = False, + ) -> BatchEncoding: + """ + Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. It + adds special tokens, truncates sequences if overflowing while taking into account the special tokens and + manages a moving window (with user defined stride) for overflowing tokens + + Args: + batch_ids_pairs: list of tokenized input ids or input ids pairs + """ + + batch_outputs = {} + for first_ids, second_ids in batch_ids_pairs: + outputs = self.prepare_for_model( + first_ids, + second_ids, + add_special_tokens=add_special_tokens, + padding=PaddingStrategy.DO_NOT_PAD.value, # we pad in batch afterward + truncation=truncation_strategy.value, + max_length=max_length, + stride=stride, + pad_to_multiple_of=None, # we pad in batch afterward + padding_side=None, # we pad in batch afterward + return_attention_mask=False, # we pad in batch afterward + return_token_type_ids=return_token_type_ids, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_length=return_length, + return_tensors=None, # We convert the whole batch to tensors at the end + prepend_batch_axis=False, + verbose=verbose, + split_special_tokens=split_special_tokens, + ) + + for key, value in outputs.items(): + if key not in batch_outputs: + batch_outputs[key] = [] + batch_outputs[key].append(value) + + batch_outputs = self.pad( + batch_outputs, + padding=padding_strategy.value, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_attention_mask=return_attention_mask, + ) + + batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors) + + return batch_outputs + + def prepare_for_tokenization( + self, text: str, is_split_into_words: bool = False, **kwargs + ) -> tuple[str, dict[str, Any]]: + """ + Performs any necessary transformations before tokenization. + + This method should pop the arguments from kwargs and return the remaining `kwargs` as well. We test the + `kwargs` at the end of the encoding process to be sure all the arguments have been used. + + Args: + text (`str`): + The text to prepare. + is_split_into_words (`bool`, *optional*, defaults to `False`): + Whether or not the input is already pre-tokenized (e.g., split into words). If set to `True`, the + tokenizer assumes the input is already split into words (for instance, by splitting it on whitespace) + which it will tokenize. This is useful for NER or token classification. + kwargs (`dict[str, Any]`, *optional*): + Keyword arguments to use for the tokenization. + + Returns: + `tuple[str, dict[str, Any]]`: The prepared text and the unused kwargs. + """ + return (text, kwargs) + + def get_special_tokens_mask( + self, token_ids_0: list, token_ids_1: Optional[list] = None, already_has_special_tokens: bool = False + ) -> list[int]: + """ + Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods. + + Args: + token_ids_0 (`list[int]`): + List of ids of the first sequence. + token_ids_1 (`list[int]`, *optional*): + List of ids of the second sequence. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + if already_has_special_tokens: + if token_ids_1 is not None: + raise ValueError( + "You should not supply a second sequence if the provided sequence of " + "ids is already formatted with special tokens for the model." + ) + + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + return [0] * ((len(token_ids_1) if token_ids_1 else 0) + len(token_ids_0)) + + @overload + def convert_ids_to_tokens(self, ids: int, skip_special_tokens: bool = False) -> str: ... + + @overload + def convert_ids_to_tokens(self, ids: list[int], skip_special_tokens: bool = False) -> list[str]: ... + + def convert_ids_to_tokens( + self, ids: Union[int, list[int]], skip_special_tokens: bool = False + ) -> Union[str, list[str]]: + """ + Converts a single index or a sequence of indices in a token or a sequence of tokens, using the vocabulary and + added tokens. + + Args: + ids (`int` or `list[int]`): + The token id (or token ids) to convert to tokens. + skip_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to remove special tokens in the decoding. + + Returns: + `str` or `list[str]`: The decoded token(s). + """ + if isinstance(ids, int): + if ids in self._added_tokens_decoder: + return self._added_tokens_decoder[ids].content + else: + return self._convert_id_to_token(ids) + tokens = [] + for index in ids: + index = int(index) + if skip_special_tokens and index in self.all_special_ids: + continue + if index in self._added_tokens_decoder: + tokens.append(self._added_tokens_decoder[index].content) + else: + tokens.append(self._convert_id_to_token(index)) + return tokens + + def _convert_id_to_token(self, index: int) -> str: + raise NotImplementedError + + def convert_tokens_to_string(self, tokens: list[str]) -> str: + return " ".join(tokens) + + def _decode( + self, + token_ids: Union[int, list[int]], + skip_special_tokens: bool = False, + clean_up_tokenization_spaces: Optional[bool] = None, + spaces_between_special_tokens: bool = True, + **kwargs, + ) -> str: + self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False) + + filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens) + # If given is a single id, prevents splitting the string in upcoming loop + if isinstance(filtered_tokens, str): + filtered_tokens = [filtered_tokens] + + legacy_added_tokens = set(self._added_tokens_encoder.keys()) - set(self.all_special_tokens) | { + token for token in self.additional_special_tokens if self.convert_tokens_to_ids(token) >= self.vocab_size + } + # To avoid mixing byte-level and unicode for byte-level BPT + # we need to build string separately for added tokens and byte-level tokens + # cf. https://github.com/huggingface/transformers/issues/1133 + sub_texts = [] + current_sub_text = [] + # TODO @ArthurZ in version 5, special tokens should be handled in convert_tokens_to_string, while _convert_tokens_to_string + for token in filtered_tokens: + if skip_special_tokens and token in self.all_special_tokens: + continue + if token in legacy_added_tokens: + if current_sub_text: + string = self.convert_tokens_to_string(current_sub_text) + if len(string) > 0: + sub_texts.append(string) + current_sub_text = [] + sub_texts.append(token) + else: + current_sub_text.append(token) + if current_sub_text: + sub_texts.append(self.convert_tokens_to_string(current_sub_text)) + + if spaces_between_special_tokens: + text = " ".join(sub_texts) + else: + text = "".join(sub_texts) + + clean_up_tokenization_spaces = ( + clean_up_tokenization_spaces + if clean_up_tokenization_spaces is not None + else self.clean_up_tokenization_spaces + ) + if clean_up_tokenization_spaces: + clean_text = self.clean_up_tokenization(text) + return clean_text + else: + return text diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/tokenization_utils_base.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/tokenization_utils_base.py new file mode 100644 index 0000000000000000000000000000000000000000..2dbbe296aa22c87e7dc4c13d1e1ccd4de390c6f0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/tokenization_utils_base.py @@ -0,0 +1,4366 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Base classes common to both the slow and the fast tokenization classes: PreTrainedTokenizerBase (host all the user +fronting encoding methods) Special token mixing (host the special tokens logic) and BatchEncoding (wrap the dictionary +of output with special method for the Fast tokenizers) +""" + +import copy +import json +import os +import re +import warnings +from collections import UserDict +from collections.abc import Mapping, Sequence, Sized +from contextlib import contextmanager +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable, NamedTuple, Optional, Union + +import numpy as np +from huggingface_hub import list_repo_files +from packaging import version + +from . import __version__ +from .dynamic_module_utils import custom_object_save +from .utils import ( + CHAT_TEMPLATE_DIR, + CHAT_TEMPLATE_FILE, + ExplicitEnum, + PaddingStrategy, + PushToHubMixin, + TensorType, + add_end_docstrings, + cached_file, + copy_func, + download_url, + extract_commit_hash, + is_flax_available, + is_jax_tensor, + is_mlx_available, + is_numpy_array, + is_offline_mode, + is_protobuf_available, + is_remote_url, + is_tf_available, + is_tf_tensor, + is_tokenizers_available, + is_torch_available, + is_torch_device, + is_torch_tensor, + list_repo_templates, + logging, + requires_backends, + to_py_obj, +) +from .utils.chat_template_utils import render_jinja_template +from .utils.import_utils import PROTOBUF_IMPORT_ERROR + + +if TYPE_CHECKING: + if is_torch_available(): + import torch + if is_tf_available(): + import tensorflow as tf + if is_flax_available(): + import jax.numpy as jnp # noqa: F401 + + +def import_protobuf_decode_error(error_message=""): + if is_protobuf_available(): + from google.protobuf.message import DecodeError + + return DecodeError + else: + raise ImportError(PROTOBUF_IMPORT_ERROR.format(error_message)) + + +def flatten(arr: list): + res = [] + if len(arr) > 0: + for sub_arr in arr: + if isinstance(arr[0], (list, tuple)): + res.extend(flatten(sub_arr)) + else: + res.append(sub_arr) + return res + + +if is_tokenizers_available() or TYPE_CHECKING: + from tokenizers import Encoding as EncodingFast + +if is_tokenizers_available(): + from tokenizers import AddedToken +else: + + @dataclass(frozen=False, eq=True) + class AddedToken: + """ + AddedToken represents a token to be added to a Tokenizer An AddedToken can have special options defining the + way it should behave. + + The `normalized` will default to `not special` if it is not specified, similarly to the definition in + `tokenizers`. + """ + + def __init__( + self, content: str, single_word=False, lstrip=False, rstrip=False, special=False, normalized=None + ): + self.content = content + self.single_word = single_word + self.lstrip = lstrip + self.rstrip = rstrip + self.special = special + self.normalized = normalized if normalized is not None else not special + + def __getstate__(self): + return self.__dict__ + + def __str__(self): + return self.content + + +logger = logging.get_logger(__name__) + +VERY_LARGE_INTEGER = int(1e30) # This is used to set the max input length for a model with infinite size input +LARGE_INTEGER = int(1e20) # This is used when we need something big but slightly smaller than VERY_LARGE_INTEGER + +# Define type aliases and NamedTuples +TextInput = str +PreTokenizedInput = list[str] +EncodedInput = list[int] +TextInputPair = tuple[str, str] +PreTokenizedInputPair = tuple[list[str], list[str]] +EncodedInputPair = tuple[list[int], list[int]] + +# Define type aliases for text-related non-text modalities +AudioInput = Union[np.ndarray, "torch.Tensor", list[np.ndarray], list["torch.Tensor"]] + +# Slow tokenizers used to be saved in three separated files +SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json" +ADDED_TOKENS_FILE = "added_tokens.json" +TOKENIZER_CONFIG_FILE = "tokenizer_config.json" + +# Fast tokenizers (provided by HuggingFace tokenizer's library) can be saved in a single file +FULL_TOKENIZER_FILE = "tokenizer.json" +_re_tokenizer_file = re.compile(r"tokenizer\.(.*)\.json") + + +class TruncationStrategy(ExplicitEnum): + """ + Possible values for the `truncation` argument in [`PreTrainedTokenizerBase.__call__`]. Useful for tab-completion in + an IDE. + """ + + ONLY_FIRST = "only_first" + ONLY_SECOND = "only_second" + LONGEST_FIRST = "longest_first" + DO_NOT_TRUNCATE = "do_not_truncate" + + +class CharSpan(NamedTuple): + """ + Character span in the original string. + + Args: + start (`int`): Index of the first character in the original string. + end (`int`): Index of the character following the last character in the original string. + """ + + start: int + end: int + + +class TokenSpan(NamedTuple): + """ + Token span in an encoded string (list of tokens). + + Args: + start (`int`): Index of the first token in the span. + end (`int`): Index of the token following the last token in the span. + """ + + start: int + end: int + + +class BatchEncoding(UserDict): + """ + Holds the output of the [`~tokenization_utils_base.PreTrainedTokenizerBase.__call__`], + [`~tokenization_utils_base.PreTrainedTokenizerBase.encode_plus`] and + [`~tokenization_utils_base.PreTrainedTokenizerBase.batch_encode_plus`] methods (tokens, attention_masks, etc). + + This class is derived from a python dictionary and can be used as a dictionary. In addition, this class exposes + utility methods to map from word/character space to token space. + + Args: + data (`dict`, *optional*): + Dictionary of lists/arrays/tensors returned by the `__call__`/`encode_plus`/`batch_encode_plus` methods + ('input_ids', 'attention_mask', etc.). + encoding (`tokenizers.Encoding` or `Sequence[tokenizers.Encoding]`, *optional*): + If the tokenizer is a fast tokenizer which outputs additional information like mapping from word/character + space to token space the `tokenizers.Encoding` instance or list of instance (for batches) hold this + information. + tensor_type (`Union[None, str, TensorType]`, *optional*): + You can give a tensor_type here to convert the lists of integers in PyTorch/TensorFlow/Numpy Tensors at + initialization. + prepend_batch_axis (`bool`, *optional*, defaults to `False`): + Whether or not to add a batch axis when converting to tensors (see `tensor_type` above). Note that this + parameter has an effect if the parameter `tensor_type` is set, *otherwise has no effect*. + n_sequences (`Optional[int]`, *optional*): + You can give a tensor_type here to convert the lists of integers in PyTorch/TensorFlow/Numpy Tensors at + initialization. + """ + + def __init__( + self, + data: Optional[dict[str, Any]] = None, + encoding: Optional[Union[EncodingFast, Sequence[EncodingFast]]] = None, + tensor_type: Union[None, str, TensorType] = None, + prepend_batch_axis: bool = False, + n_sequences: Optional[int] = None, + ): + super().__init__(data) + + # If encoding is not None, the fast tokenization is used + if encoding is not None and isinstance(encoding, EncodingFast): + encoding = [encoding] + + self._encodings = encoding + + if n_sequences is None and encoding is not None and encoding: + n_sequences = encoding[0].n_sequences + + self._n_sequences = n_sequences + + self.convert_to_tensors(tensor_type=tensor_type, prepend_batch_axis=prepend_batch_axis) + + @property + def n_sequences(self) -> Optional[int]: + """ + `Optional[int]`: The number of sequences used to generate each sample from the batch encoded in this + [`BatchEncoding`]. Currently can be one of `None` (unknown), `1` (a single sentence) or `2` (a pair of + sentences) + """ + return self._n_sequences + + @property + def is_fast(self) -> bool: + """ + `bool`: Indicate whether this [`BatchEncoding`] was generated from the result of a [`PreTrainedTokenizerFast`] + or not. + """ + return self._encodings is not None + + def __getitem__(self, item: Union[int, str]) -> Union[Any, EncodingFast]: + """ + If the key is a string, returns the value of the dict associated to `key` ('input_ids', 'attention_mask', + etc.). + + If the key is an integer, get the `tokenizers.Encoding` for batch item with index `key`. + + If the key is a slice, returns the value of the dict associated to `key` ('input_ids', 'attention_mask', etc.) + with the constraint of slice. + """ + if isinstance(item, str): + return self.data[item] + elif self._encodings is not None: + return self._encodings[item] + elif isinstance(item, slice): + return {key: self.data[key][item] for key in self.data} + else: + raise KeyError( + "Invalid key. Only three types of key are available: " + "(1) string, (2) integers for backend Encoding, and (3) slices for data subsetting." + ) + + def __getattr__(self, item: str): + try: + return self.data[item] + except KeyError: + raise AttributeError + + def __getstate__(self): + return {"data": self.data, "encodings": self._encodings} + + def __setstate__(self, state): + if "data" in state: + self.data = state["data"] + + if "encodings" in state: + self._encodings = state["encodings"] + + # After this point: + # Extended properties and methods only available for fast (Rust-based) tokenizers + # provided by HuggingFace tokenizers library. + + @property + def encodings(self) -> Optional[list[EncodingFast]]: + """ + `Optional[list[tokenizers.Encoding]]`: The list all encodings from the tokenization process. Returns `None` if + the input was tokenized through Python (i.e., not a fast) tokenizer. + """ + return self._encodings + + def tokens(self, batch_index: int = 0) -> list[str]: + """ + Return the list of tokens (sub-parts of the input strings after word/subword splitting and before conversion to + integer indices) at a given batch index (only works for the output of a fast tokenizer). + + Args: + batch_index (`int`, *optional*, defaults to 0): The index to access in the batch. + + Returns: + `list[str]`: The list of tokens at that index. + """ + if not self._encodings: + raise ValueError( + "tokens() is not available when using non-fast tokenizers (e.g. instance of a `XxxTokenizerFast`" + " class)." + ) + return self._encodings[batch_index].tokens + + def sequence_ids(self, batch_index: int = 0) -> list[Optional[int]]: + """ + Return a list mapping the tokens to the id of their original sentences: + + - `None` for special tokens added around or between sequences, + - `0` for tokens corresponding to words in the first sequence, + - `1` for tokens corresponding to words in the second sequence when a pair of sequences was jointly + encoded. + + Args: + batch_index (`int`, *optional*, defaults to 0): The index to access in the batch. + + Returns: + `list[Optional[int]]`: A list indicating the sequence id corresponding to each token. Special tokens added + by the tokenizer are mapped to `None` and other tokens are mapped to the index of their corresponding + sequence. + """ + if not self._encodings: + raise ValueError( + "sequence_ids() is not available when using non-fast tokenizers (e.g. instance of a `XxxTokenizerFast`" + " class)." + ) + return self._encodings[batch_index].sequence_ids + + def words(self, batch_index: int = 0) -> list[Optional[int]]: + """ + Return a list mapping the tokens to their actual word in the initial sentence for a fast tokenizer. + + Args: + batch_index (`int`, *optional*, defaults to 0): The index to access in the batch. + + Returns: + `list[Optional[int]]`: A list indicating the word corresponding to each token. Special tokens added by the + tokenizer are mapped to `None` and other tokens are mapped to the index of their corresponding word + (several tokens will be mapped to the same word index if they are parts of that word). + """ + if not self._encodings: + raise ValueError( + "words() is not available when using non-fast tokenizers (e.g. instance of a `XxxTokenizerFast`" + " class)." + ) + warnings.warn( + "`BatchEncoding.words()` property is deprecated and should be replaced with the identical, " + "but more self-explanatory `BatchEncoding.word_ids()` property.", + FutureWarning, + ) + return self.word_ids(batch_index) + + def word_ids(self, batch_index: int = 0) -> list[Optional[int]]: + """ + Return a list mapping the tokens to their actual word in the initial sentence for a fast tokenizer. + + Args: + batch_index (`int`, *optional*, defaults to 0): The index to access in the batch. + + Returns: + `list[Optional[int]]`: A list indicating the word corresponding to each token. Special tokens added by the + tokenizer are mapped to `None` and other tokens are mapped to the index of their corresponding word + (several tokens will be mapped to the same word index if they are parts of that word). + """ + if not self._encodings: + raise ValueError( + "word_ids() is not available when using non-fast tokenizers (e.g. instance of a `XxxTokenizerFast`" + " class)." + ) + return self._encodings[batch_index].word_ids + + def token_to_sequence(self, batch_or_token_index: int, token_index: Optional[int] = None) -> int: + """ + Get the index of the sequence represented by the given token. In the general use case, this method returns `0` + for a single sequence or the first sequence of a pair, and `1` for the second sequence of a pair + + Can be called as: + + - `self.token_to_sequence(token_index)` if batch size is 1 + - `self.token_to_sequence(batch_index, token_index)` if batch size is greater than 1 + + This method is particularly suited when the input sequences are provided as pre-tokenized sequences (i.e., + words are defined by the user). In this case it allows to easily associate encoded tokens with provided + tokenized words. + + Args: + batch_or_token_index (`int`): + Index of the sequence in the batch. If the batch only comprises one sequence, this can be the index of + the token in the sequence. + token_index (`int`, *optional*): + If a batch index is provided in *batch_or_token_index*, this can be the index of the token in the + sequence. + + Returns: + `int`: Index of the word in the input sequence. + """ + + if not self._encodings: + raise ValueError("token_to_sequence() is not available when using Python based tokenizers") + if token_index is not None: + batch_index = batch_or_token_index + else: + batch_index = 0 + token_index = batch_or_token_index + if batch_index < 0: + batch_index = self._batch_size + batch_index + if token_index < 0: + token_index = self._seq_len + token_index + return self._encodings[batch_index].token_to_sequence(token_index) + + def token_to_word(self, batch_or_token_index: int, token_index: Optional[int] = None) -> int: + """ + Get the index of the word corresponding (i.e. comprising) to an encoded token in a sequence of the batch. + + Can be called as: + + - `self.token_to_word(token_index)` if batch size is 1 + - `self.token_to_word(batch_index, token_index)` if batch size is greater than 1 + + This method is particularly suited when the input sequences are provided as pre-tokenized sequences (i.e., + words are defined by the user). In this case it allows to easily associate encoded tokens with provided + tokenized words. + + Args: + batch_or_token_index (`int`): + Index of the sequence in the batch. If the batch only comprise one sequence, this can be the index of + the token in the sequence. + token_index (`int`, *optional*): + If a batch index is provided in *batch_or_token_index*, this can be the index of the token in the + sequence. + + Returns: + `int`: Index of the word in the input sequence. + """ + + if not self._encodings: + raise ValueError("token_to_word() is not available when using Python based tokenizers") + if token_index is not None: + batch_index = batch_or_token_index + else: + batch_index = 0 + token_index = batch_or_token_index + if batch_index < 0: + batch_index = self._batch_size + batch_index + if token_index < 0: + token_index = self._seq_len + token_index + return self._encodings[batch_index].token_to_word(token_index) + + def word_to_tokens( + self, batch_or_word_index: int, word_index: Optional[int] = None, sequence_index: int = 0 + ) -> Optional[TokenSpan]: + """ + Get the encoded token span corresponding to a word in a sequence of the batch. + + Token spans are returned as a [`~tokenization_utils_base.TokenSpan`] with: + + - **start** -- Index of the first token. + - **end** -- Index of the token following the last token. + + Can be called as: + + - `self.word_to_tokens(word_index, sequence_index: int = 0)` if batch size is 1 + - `self.word_to_tokens(batch_index, word_index, sequence_index: int = 0)` if batch size is greater or equal to + 1 + + This method is particularly suited when the input sequences are provided as pre-tokenized sequences (i.e. words + are defined by the user). In this case it allows to easily associate encoded tokens with provided tokenized + words. + + Args: + batch_or_word_index (`int`): + Index of the sequence in the batch. If the batch only comprises one sequence, this can be the index of + the word in the sequence. + word_index (`int`, *optional*): + If a batch index is provided in *batch_or_token_index*, this can be the index of the word in the + sequence. + sequence_index (`int`, *optional*, defaults to 0): + If pair of sequences are encoded in the batch this can be used to specify which sequence in the pair (0 + or 1) the provided word index belongs to. + + Returns: + ([`~tokenization_utils_base.TokenSpan`], *optional*): Span of tokens in the encoded sequence. Returns + `None` if no tokens correspond to the word. This can happen especially when the token is a special token + that has been used to format the tokenization. For example when we add a class token at the very beginning + of the tokenization. + """ + + if not self._encodings: + raise ValueError("word_to_tokens() is not available when using Python based tokenizers") + if word_index is not None: + batch_index = batch_or_word_index + else: + batch_index = 0 + word_index = batch_or_word_index + if batch_index < 0: + batch_index = self._batch_size + batch_index + if word_index < 0: + word_index = self._seq_len + word_index + span = self._encodings[batch_index].word_to_tokens(word_index, sequence_index) + return TokenSpan(*span) if span is not None else None + + def token_to_chars(self, batch_or_token_index: int, token_index: Optional[int] = None) -> Optional[CharSpan]: + """ + Get the character span corresponding to an encoded token in a sequence of the batch. + + Character spans are returned as a [`~tokenization_utils_base.CharSpan`] with: + + - **start** -- Index of the first character in the original string associated to the token. + - **end** -- Index of the character following the last character in the original string associated to the + token. + + Can be called as: + + - `self.token_to_chars(token_index)` if batch size is 1 + - `self.token_to_chars(batch_index, token_index)` if batch size is greater or equal to 1 + + Args: + batch_or_token_index (`int`): + Index of the sequence in the batch. If the batch only comprise one sequence, this can be the index of + the token in the sequence. + token_index (`int`, *optional*): + If a batch index is provided in *batch_or_token_index*, this can be the index of the token or tokens in + the sequence. + + Returns: + [`~tokenization_utils_base.CharSpan`]: Span of characters in the original string, or None, if the token + (e.g. , ) doesn't correspond to any chars in the origin string. + """ + + if not self._encodings: + raise ValueError("token_to_chars() is not available when using Python based tokenizers") + if token_index is not None: + batch_index = batch_or_token_index + else: + batch_index = 0 + token_index = batch_or_token_index + span_indices = self._encodings[batch_index].token_to_chars(token_index) + + return CharSpan(*span_indices) if span_indices is not None else None + + def char_to_token( + self, batch_or_char_index: int, char_index: Optional[int] = None, sequence_index: int = 0 + ) -> int: + """ + Get the index of the token in the encoded output comprising a character in the original string for a sequence + of the batch. + + Can be called as: + + - `self.char_to_token(char_index)` if batch size is 1 + - `self.char_to_token(batch_index, char_index)` if batch size is greater or equal to 1 + + This method is particularly suited when the input sequences are provided as pre-tokenized sequences (i.e. words + are defined by the user). In this case it allows to easily associate encoded tokens with provided tokenized + words. + + Args: + batch_or_char_index (`int`): + Index of the sequence in the batch. If the batch only comprise one sequence, this can be the index of + the word in the sequence + char_index (`int`, *optional*): + If a batch index is provided in *batch_or_token_index*, this can be the index of the word in the + sequence. + sequence_index (`int`, *optional*, defaults to 0): + If pair of sequences are encoded in the batch this can be used to specify which sequence in the pair (0 + or 1) the provided character index belongs to. + + + Returns: + `int`: Index of the token, or None if the char index refers to a whitespace only token and whitespace is + trimmed with `trim_offsets=True`. + """ + + if not self._encodings: + raise ValueError("char_to_token() is not available when using Python based tokenizers") + if char_index is not None: + batch_index = batch_or_char_index + else: + batch_index = 0 + char_index = batch_or_char_index + return self._encodings[batch_index].char_to_token(char_index, sequence_index) + + def word_to_chars( + self, batch_or_word_index: int, word_index: Optional[int] = None, sequence_index: int = 0 + ) -> CharSpan: + """ + Get the character span in the original string corresponding to given word in a sequence of the batch. + + Character spans are returned as a CharSpan NamedTuple with: + + - start: index of the first character in the original string + - end: index of the character following the last character in the original string + + Can be called as: + + - `self.word_to_chars(word_index)` if batch size is 1 + - `self.word_to_chars(batch_index, word_index)` if batch size is greater or equal to 1 + + Args: + batch_or_word_index (`int`): + Index of the sequence in the batch. If the batch only comprise one sequence, this can be the index of + the word in the sequence + word_index (`int`, *optional*): + If a batch index is provided in *batch_or_token_index*, this can be the index of the word in the + sequence. + sequence_index (`int`, *optional*, defaults to 0): + If pair of sequences are encoded in the batch this can be used to specify which sequence in the pair (0 + or 1) the provided word index belongs to. + + Returns: + `CharSpan` or `list[CharSpan]`: Span(s) of the associated character or characters in the string. CharSpan + are NamedTuple with: + + - start: index of the first character associated to the token in the original string + - end: index of the character following the last character associated to the token in the original + string + """ + + if not self._encodings: + raise ValueError("word_to_chars() is not available when using Python based tokenizers") + if word_index is not None: + batch_index = batch_or_word_index + else: + batch_index = 0 + word_index = batch_or_word_index + return CharSpan(*(self._encodings[batch_index].word_to_chars(word_index, sequence_index))) + + def char_to_word(self, batch_or_char_index: int, char_index: Optional[int] = None, sequence_index: int = 0) -> int: + """ + Get the word in the original string corresponding to a character in the original string of a sequence of the + batch. + + Can be called as: + + - `self.char_to_word(char_index)` if batch size is 1 + - `self.char_to_word(batch_index, char_index)` if batch size is greater than 1 + + This method is particularly suited when the input sequences are provided as pre-tokenized sequences (i.e. words + are defined by the user). In this case it allows to easily associate encoded tokens with provided tokenized + words. + + Args: + batch_or_char_index (`int`): + Index of the sequence in the batch. If the batch only comprise one sequence, this can be the index of + the character in the original string. + char_index (`int`, *optional*): + If a batch index is provided in *batch_or_token_index*, this can be the index of the character in the + original string. + sequence_index (`int`, *optional*, defaults to 0): + If pair of sequences are encoded in the batch this can be used to specify which sequence in the pair (0 + or 1) the provided character index belongs to. + + + Returns: + `int` or `list[int]`: Index or indices of the associated encoded token(s). + """ + + if not self._encodings: + raise ValueError("char_to_word() is not available when using Python based tokenizers") + if char_index is not None: + batch_index = batch_or_char_index + else: + batch_index = 0 + char_index = batch_or_char_index + return self._encodings[batch_index].char_to_word(char_index, sequence_index) + + def convert_to_tensors( + self, tensor_type: Optional[Union[str, TensorType]] = None, prepend_batch_axis: bool = False + ): + """ + Convert the inner content to tensors. + + Args: + tensor_type (`str` or [`~utils.TensorType`], *optional*): + The type of tensors to use. If `str`, should be one of the values of the enum [`~utils.TensorType`]. If + `None`, no modification is done. + prepend_batch_axis (`int`, *optional*, defaults to `False`): + Whether or not to add the batch dimension during the conversion. + """ + if tensor_type is None: + return self + + # Convert to TensorType + if not isinstance(tensor_type, TensorType): + tensor_type = TensorType(tensor_type) + + # Get a function reference for the correct framework + if tensor_type == TensorType.TENSORFLOW: + if not is_tf_available(): + raise ImportError( + "Unable to convert output to TensorFlow tensors format, TensorFlow is not installed." + ) + import tensorflow as tf + + def as_tensor(value, dtype=None): + if len(flatten(value)) == 0 and dtype is None: + dtype = tf.int32 + return tf.constant(value, dtype=dtype) + + is_tensor = tf.is_tensor + + elif tensor_type == TensorType.PYTORCH: + if not is_torch_available(): + raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.") + import torch + + def as_tensor(value, dtype=None): + if isinstance(value, list) and len(value) > 0 and isinstance(value[0], np.ndarray): + return torch.from_numpy(np.array(value)) + if len(flatten(value)) == 0 and dtype is None: + dtype = torch.int64 + return torch.tensor(value, dtype=dtype) + + is_tensor = torch.is_tensor + + elif tensor_type == TensorType.JAX: + if not is_flax_available(): + raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.") + import jax.numpy as jnp # noqa: F811 + + def as_tensor(value, dtype=None): + if len(flatten(value)) == 0 and dtype is None: + dtype = jnp.int32 + return jnp.array(value, dtype=dtype) + + is_tensor = is_jax_tensor + + elif tensor_type == TensorType.MLX: + if not is_mlx_available(): + raise ImportError("Unable to convert output to MLX tensors format, MLX is not installed.") + import mlx.core as mx + + def as_tensor(value, dtype=None): + if len(flatten(value)) == 0 and dtype is None: + dtype = mx.int32 + return mx.array(value, dtype=dtype) + + def is_tensor(obj): + return isinstance(obj, mx.array) + else: + + def as_tensor(value, dtype=None): + if ( + isinstance(value, (list, tuple)) + and len(value) > 0 + and isinstance(value[0], (list, tuple, np.ndarray)) + ): + value_lens = [len(val) for val in value] + if len(set(value_lens)) > 1 and dtype is None: + # we have a ragged list so handle explicitly + value = as_tensor([np.asarray(val) for val in value], dtype=object) + if len(flatten(value)) == 0 and dtype is None: + dtype = np.int64 + return np.asarray(value, dtype=dtype) + + is_tensor = is_numpy_array + + # Do the tensor conversion in batch + for key, value in self.items(): + try: + if prepend_batch_axis: + value = [value] + + if not is_tensor(value): + tensor = as_tensor(value) + + # Removing this for now in favor of controlling the shape with `prepend_batch_axis` + # # at-least2d + # if tensor.ndim > 2: + # tensor = tensor.squeeze(0) + # elif tensor.ndim < 2: + # tensor = tensor[None, :] + + self[key] = tensor + except Exception as e: + if key == "overflowing_tokens": + raise ValueError( + "Unable to create tensor returning overflowing tokens of different lengths. " + "Please see if a fast version of this tokenizer is available to have this feature available." + ) from e + raise ValueError( + "Unable to create tensor, you should probably activate truncation and/or padding with" + " 'padding=True' 'truncation=True' to have batched tensors with the same length. Perhaps your" + f" features (`{key}` in this case) have excessive nesting (inputs type `list` where type `int` is" + " expected)." + ) from e + + return self + + def to(self, device: Union[str, "torch.device"], *, non_blocking: bool = False) -> "BatchEncoding": + """ + Send all values to device by calling `v.to(device, non_blocking=non_blocking)` (PyTorch only). + + Args: + device (`str` or `torch.device`): The device to put the tensors on. + non_blocking (`bool`): Whether to perform the copy asynchronously. + + Returns: + [`BatchEncoding`]: The same instance after modification. + """ + requires_backends(self, ["torch"]) + + # This check catches things like APEX blindly calling "to" on all inputs to a module + # Otherwise it passes the casts down and casts the LongTensor containing the token idxs + # into a HalfTensor + if isinstance(device, str) or is_torch_device(device) or isinstance(device, int): + self.data = { + k: v.to(device=device, non_blocking=non_blocking) if hasattr(v, "to") and callable(v.to) else v + for k, v in self.data.items() + } + else: + logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.") + return self + + +class SpecialTokensMixin: + """ + A mixin derived by [`PreTrainedTokenizer`] and [`PreTrainedTokenizerFast`] to handle specific behaviors related to + special tokens. In particular, this class hold the attributes which can be used to directly access these special + tokens in a model-independent manner and allow to set and update the special tokens. + + Args: + bos_token (`str` or `tokenizers.AddedToken`, *optional*): + A special token representing the beginning of a sentence. + eos_token (`str` or `tokenizers.AddedToken`, *optional*): + A special token representing the end of a sentence. + unk_token (`str` or `tokenizers.AddedToken`, *optional*): + A special token representing an out-of-vocabulary token. + sep_token (`str` or `tokenizers.AddedToken`, *optional*): + A special token separating two different sentences in the same input (used by BERT for instance). + pad_token (`str` or `tokenizers.AddedToken`, *optional*): + A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by + attention mechanisms or loss computation. + cls_token (`str` or `tokenizers.AddedToken`, *optional*): + A special token representing the class of the input (used by BERT for instance). + mask_token (`str` or `tokenizers.AddedToken`, *optional*): + A special token representing a masked token (used by masked-language modeling pretraining objectives, like + BERT). + additional_special_tokens (tuple or list of `str` or `tokenizers.AddedToken`, *optional*): + A tuple or a list of additional tokens, which will be marked as `special`, meaning that they will be + skipped when decoding if `skip_special_tokens` is set to `True`. + """ + + SPECIAL_TOKENS_ATTRIBUTES = [ + "bos_token", + "eos_token", + "unk_token", + "sep_token", + "pad_token", + "cls_token", + "mask_token", + "additional_special_tokens", + ] + + def __init__(self, verbose=False, **kwargs): + self._pad_token_type_id = 0 + self.verbose = verbose + self._special_tokens_map = dict.fromkeys(self.SPECIAL_TOKENS_ATTRIBUTES) + self._special_tokens_map["additional_special_tokens"] = [] # for BC where it defaults to empty list + + # We directly set the hidden value to allow initialization with special tokens + # which are not yet in the vocabulary. Necessary for serialization/de-serialization + # TODO clean this up at some point (probably by switching to fast tokenizers) + + for key, value in kwargs.items(): + if value is None: + continue + if key in self.SPECIAL_TOKENS_ATTRIBUTES: + if key == "additional_special_tokens": + assert isinstance(value, (list, tuple)), f"Value {value} is not a list or tuple" + assert all(isinstance(t, (str, AddedToken)) for t in value), ( + "One of the tokens is not a string or an AddedToken" + ) + setattr(self, key, value) + elif isinstance(value, (str, AddedToken)): + setattr(self, key, value) + else: + raise TypeError(f"Special token {key} has to be either str or AddedToken but got: {type(value)}") + + def sanitize_special_tokens(self) -> int: + """ + The `sanitize_special_tokens` is now deprecated kept for backward compatibility and will be removed in + transformers v5. + """ + logger.warning_once("The `sanitize_special_tokens` will be removed in transformers v5.") + return self.add_tokens(self.all_special_tokens_extended, special_tokens=True) + + def add_special_tokens( + self, + special_tokens_dict: dict[str, Union[str, AddedToken, Sequence[Union[str, AddedToken]]]], + replace_additional_special_tokens=True, + ) -> int: + """ + Add a dictionary of special tokens (eos, pad, cls, etc.) to the encoder and link them to class attributes. If + special tokens are NOT in the vocabulary, they are added to it (indexed starting from the last index of the + current vocabulary). + + When adding new tokens to the vocabulary, you should make sure to also resize the token embedding matrix of the + model so that its embedding matrix matches the tokenizer. + + In order to do that, please use the [`~PreTrainedModel.resize_token_embeddings`] method. + + Using `add_special_tokens` will ensure your special tokens can be used in several ways: + + - Special tokens can be skipped when decoding using `skip_special_tokens = True`. + - Special tokens are carefully handled by the tokenizer (they are never split), similar to `AddedTokens`. + - You can easily refer to special tokens using tokenizer class attributes like `tokenizer.cls_token`. This + makes it easy to develop model-agnostic training and fine-tuning scripts. + + When possible, special tokens are already registered for provided pretrained models (for instance + [`BertTokenizer`] `cls_token` is already registered to be `'[CLS]'` and XLM's one is also registered to be + `''`). + + Args: + special_tokens_dict (dictionary *str* to *str*, `tokenizers.AddedToken`, or `Sequence[Union[str, AddedToken]]`): + Keys should be in the list of predefined special attributes: [`bos_token`, `eos_token`, `unk_token`, + `sep_token`, `pad_token`, `cls_token`, `mask_token`, `additional_special_tokens`]. + + Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer + assign the index of the `unk_token` to them). + replace_additional_special_tokens (`bool`, *optional*,, defaults to `True`): + If `True`, the existing list of additional special tokens will be replaced by the list provided in + `special_tokens_dict`. Otherwise, `self._special_tokens_map["additional_special_tokens"]` is just extended. In the former + case, the tokens will NOT be removed from the tokenizer's full vocabulary - they are only being flagged + as non-special tokens. Remember, this only affects which tokens are skipped during decoding, not the + `added_tokens_encoder` and `added_tokens_decoder`. This means that the previous + `additional_special_tokens` are still added tokens, and will not be split by the model. + + Returns: + `int`: Number of tokens added to the vocabulary. + + Examples: + + ```python + # Let's see how to add a new classification token to GPT-2 + tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2") + model = GPT2Model.from_pretrained("openai-community/gpt2") + + special_tokens_dict = {"cls_token": ""} + + num_added_toks = tokenizer.add_special_tokens(special_tokens_dict) + print("We have added", num_added_toks, "tokens") + # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e., the length of the tokenizer. + model.resize_token_embeddings(len(tokenizer)) + + assert tokenizer.cls_token == "" + ```""" + if not special_tokens_dict: + return 0 + + added_tokens = [] + for key, value in special_tokens_dict.items(): + assert key in self.SPECIAL_TOKENS_ATTRIBUTES, f"Key {key} is not a special token" + + if self.verbose: + logger.info(f"Assigning {value} to the {key} key of the tokenizer") + + if key == "additional_special_tokens": + assert isinstance(value, (list, tuple)) and all(isinstance(t, (str, AddedToken)) for t in value), ( + f"Tokens {value} for key {key} should all be str or AddedToken instances" + ) + + to_add = [] + for token in value: + if isinstance(token, str): + # for legacy purpose we default to stripping. `test_add_tokens_tokenizer` depends on this + token = AddedToken(token, rstrip=False, lstrip=False, normalized=False, special=True) + if not replace_additional_special_tokens and str(token) in self.additional_special_tokens: + continue + to_add.append(token) + if replace_additional_special_tokens and len(to_add) > 0: + setattr(self, key, list(to_add)) + else: + self._special_tokens_map["additional_special_tokens"].extend(to_add) + added_tokens += to_add + + else: + if not isinstance(value, (str, AddedToken)): + raise ValueError(f"Token {value} for key {key} should be a str or an AddedToken instance") + if isinstance(value, (str)): + # for legacy purpose we default to stripping. `False` depends on this + value = AddedToken(value, rstrip=False, lstrip=False, normalized=False, special=True) + if isinstance(value, AddedToken): + setattr(self, key, value) + if value not in added_tokens: + added_tokens.append(value) + + # if we are adding tokens that were not part of the vocab, we ought to add them + added_tokens = self.add_tokens(added_tokens, special_tokens=True) + return added_tokens + + def add_tokens( + self, new_tokens: Union[str, AddedToken, Sequence[Union[str, AddedToken]]], special_tokens: bool = False + ) -> int: + """ + Add a list of new tokens to the tokenizer class. If the new tokens are not in the vocabulary, they are added to + it with indices starting from length of the current vocabulary and will be isolated before the tokenization + algorithm is applied. Added tokens and tokens from the vocabulary of the tokenization algorithm are therefore + not treated in the same way. + + Note, when adding new tokens to the vocabulary, you should make sure to also resize the token embedding matrix + of the model so that its embedding matrix matches the tokenizer. + + In order to do that, please use the [`~PreTrainedModel.resize_token_embeddings`] method. + + Args: + new_tokens (`str`, `tokenizers.AddedToken` or a sequence of *str* or `tokenizers.AddedToken`): + Tokens are only added if they are not already in the vocabulary. `tokenizers.AddedToken` wraps a string + token to let you personalize its behavior: whether this token should only match against a single word, + whether this token should strip all potential whitespaces on the left side, whether this token should + strip all potential whitespaces on the right side, etc. + special_tokens (`bool`, *optional*, defaults to `False`): + Can be used to specify if the token is a special token. This mostly change the normalization behavior + (special tokens like CLS or [MASK] are usually not lower-cased for instance). + + See details for `tokenizers.AddedToken` in HuggingFace tokenizers library. + + Returns: + `int`: Number of tokens added to the vocabulary. + + Examples: + + ```python + # Let's see how to increase the vocabulary of Bert model and tokenizer + tokenizer = BertTokenizerFast.from_pretrained("google-bert/bert-base-uncased") + model = BertModel.from_pretrained("google-bert/bert-base-uncased") + + num_added_toks = tokenizer.add_tokens(["new_tok1", "my_new-tok2"]) + print("We have added", num_added_toks, "tokens") + # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e., the length of the tokenizer. + model.resize_token_embeddings(len(tokenizer)) + ```""" + if not new_tokens: + return 0 + + if not isinstance(new_tokens, (list, tuple)): + new_tokens = [new_tokens] + + return self._add_tokens(new_tokens, special_tokens=special_tokens) + + def _add_tokens(self, new_tokens: Union[list[str], list[AddedToken]], special_tokens: bool = False) -> int: + raise NotImplementedError + + @property + def pad_token_type_id(self) -> int: + """ + `int`: Id of the padding token type in the vocabulary. + """ + return self._pad_token_type_id + + def __setattr__(self, key, value): + key_without_id = key + key_is_special_id = key.endswith("_id") or key.endswith("_ids") + if key_is_special_id: + key_without_id = key[:-3] if not key.endswith("_ids") else key[:-4] + + if self.__dict__.get("_special_tokens_map", None) is not None and any( + name in self.__dict__["_special_tokens_map"] for name in [key, key_without_id] + ): + if key_is_special_id: + if value is not None: + value = ( + self.convert_ids_to_tokens(value) + if key != "additional_special_tokens" + else [self.convert_ids_to_tokens(val) for val in value] + ) + key = key_without_id + + if key != "additional_special_tokens" and not isinstance(value, (str, AddedToken)) and value is not None: + raise ValueError(f"Cannot set a non-string value as the {key}") + self._special_tokens_map[key] = value + else: + super().__setattr__(key, value) + + def __getattr__(self, key): + key_without_id = key + key_is_special_id = key.endswith("_id") or key.endswith("_ids") + if key_is_special_id: + key_without_id = key[:-3] if not key.endswith("_ids") else key[:-4] + + if self.__dict__.get("_special_tokens_map", None) is not None and any( + name in self.__dict__["_special_tokens_map"] for name in [key, key_without_id] + ): + _special_tokens_map = self.__dict__["_special_tokens_map"] + if not key_is_special_id: + if _special_tokens_map[key] is None: + if self.verbose: + logger.error(f"Using {key}, but it is not set yet.") + return None + value = _special_tokens_map[key] + return str(value) if key != "additional_special_tokens" else [str(tok) for tok in value] + else: + attr_as_tokens = getattr(self, key_without_id) + return self.convert_tokens_to_ids(attr_as_tokens) if attr_as_tokens is not None else None + + if key not in self.__dict__: + raise AttributeError(f"{self.__class__.__name__} has no attribute {key}") + else: + return super().__getattr__(key) + + @property + def special_tokens_map(self) -> dict[str, Union[str, list[str]]]: + """ + `dict[str, Union[str, list[str]]]`: A dictionary mapping special token class attributes (`cls_token`, + `unk_token`, etc.) to their values (`''`, `''`, etc.). + + Convert potential tokens of `tokenizers.AddedToken` type to string. + """ + set_attr = {} + for attr in self.SPECIAL_TOKENS_ATTRIBUTES: + attr_value = getattr(self, attr) + if attr_value: + set_attr[attr] = attr_value + return set_attr + + @property + def special_tokens_map_extended(self) -> dict[str, Union[str, AddedToken, list[Union[str, AddedToken]]]]: + """ + `dict[str, Union[str, tokenizers.AddedToken, list[Union[str, tokenizers.AddedToken]]]]`: A dictionary mapping + special token class attributes (`cls_token`, `unk_token`, etc.) to their values (`''`, `''`, etc.). + + Don't convert tokens of `tokenizers.AddedToken` type to string so they can be used to control more finely how + special tokens are tokenized. + """ + set_attr = {} + for attr in self.SPECIAL_TOKENS_ATTRIBUTES: + attr_value = self._special_tokens_map[attr] + if attr_value: + set_attr[attr] = attr_value + return set_attr + + @property + def all_special_tokens_extended(self) -> list[Union[str, AddedToken]]: + """ + `list[Union[str, tokenizers.AddedToken]]`: All the special tokens (`''`, `''`, etc.), the order has + nothing to do with the index of each tokens. If you want to know the correct indices, check + `self.added_tokens_encoder`. We can't create an order anymore as the keys are `AddedTokens` and not `Strings`. + + Don't convert tokens of `tokenizers.AddedToken` type to string so they can be used to control more finely how + special tokens are tokenized. + """ + all_tokens = [] + seen = set() + for value in self.special_tokens_map_extended.values(): + if isinstance(value, (list, tuple)): + tokens_to_add = [token for token in value if str(token) not in seen] + else: + tokens_to_add = [value] if str(value) not in seen else [] + seen.update(map(str, tokens_to_add)) + all_tokens.extend(tokens_to_add) + return all_tokens + + @property + def all_special_tokens(self) -> list[str]: + """ + `list[str]`: A list of the unique special tokens (`''`, `''`, ..., etc.). + + Convert tokens of `tokenizers.AddedToken` type to string. + """ + all_toks = [str(s) for s in self.all_special_tokens_extended] + return all_toks + + @property + def all_special_ids(self) -> list[int]: + """ + `list[int]`: List the ids of the special tokens(`''`, `''`, etc.) mapped to class attributes. + """ + all_toks = self.all_special_tokens + all_ids = self.convert_tokens_to_ids(all_toks) + return all_ids + + def _set_model_specific_special_tokens(self, special_tokens: list[str]): + """ + Adds new special tokens to the "SPECIAL_TOKENS_ATTRIBUTES" list which will be part + of "self.special_tokens" and saved as a special token in tokenizer's config. + This allows us to dynamically add new model-type specific tokens after initializing the tokenizer. + For example: if the model tokenizers is multimodal, we can support special image or audio tokens. + """ + self.SPECIAL_TOKENS_ATTRIBUTES = self.SPECIAL_TOKENS_ATTRIBUTES + list(special_tokens.keys()) + for key, value in special_tokens.items(): + if isinstance(value, (str, AddedToken)): + self._special_tokens_map[key] = value + else: + raise TypeError(f"Special token {key} has to be either str or AddedToken but got: {type(value)}") + + +ENCODE_KWARGS_DOCSTRING = r""" + add_special_tokens (`bool`, *optional*, defaults to `True`): + Whether or not to add special tokens when encoding the sequences. This will use the underlying + `PretrainedTokenizerBase.build_inputs_with_special_tokens` function, which defines which tokens are + automatically added to the input ids. This is useful if you want to add `bos` or `eos` tokens + automatically. + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): + Activates and controls padding. Accepts the following values: + + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence is provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `False`): + Activates and controls truncation. Accepts the following values: + + - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or + to the maximum acceptable input length for the model if that argument is not provided. This will + truncate token by token, removing a token from the longest sequence in the pair if a pair of + sequences (or a batch of pairs) is provided. + - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the + maximum acceptable input length for the model if that argument is not provided. This will only + truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the + maximum acceptable input length for the model if that argument is not provided. This will only + truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths + greater than the model maximum admissible input size). + max_length (`int`, *optional*): + Controls the maximum length to use by one of the truncation/padding parameters. + + If left unset or set to `None`, this will use the predefined model maximum length if a maximum length + is required by one of the truncation/padding parameters. If the model has no specific maximum input + length (like XLNet) truncation/padding to a maximum length will be deactivated. + stride (`int`, *optional*, defaults to 0): + If set to a number along with `max_length`, the overflowing tokens returned when + `return_overflowing_tokens=True` will contain some tokens from the end of the truncated sequence + returned to provide some overlap between truncated and overflowing sequences. The value of this + argument defines the number of overlapping tokens. + is_split_into_words (`bool`, *optional*, defaults to `False`): + Whether or not the input is already pre-tokenized (e.g., split into words). If set to `True`, the + tokenizer assumes the input is already split into words (for instance, by splitting it on whitespace) + which it will tokenize. This is useful for NER or token classification. + pad_to_multiple_of (`int`, *optional*): + If set will pad the sequence to a multiple of the provided value. Requires `padding` to be activated. + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability + `>= 7.5` (Volta). + padding_side (`str`, *optional*): + The side on which the model should have padding applied. Should be selected between ['right', 'left']. + Default value is picked from the class attribute of the same name. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors instead of list of python integers. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return Numpy `np.ndarray` objects. +""" + +ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING = r""" + return_token_type_ids (`bool`, *optional*): + Whether to return token type IDs. If left to the default, will return the token type IDs according to + the specific tokenizer's default, defined by the `return_outputs` attribute. + + [What are token type IDs?](../glossary#token-type-ids) + return_attention_mask (`bool`, *optional*): + Whether to return the attention mask. If left to the default, will return the attention mask according + to the specific tokenizer's default, defined by the `return_outputs` attribute. + + [What are attention masks?](../glossary#attention-mask) + return_overflowing_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to return overflowing token sequences. If a pair of sequences of input ids (or a batch + of pairs) is provided with `truncation_strategy = longest_first` or `True`, an error is raised instead + of returning overflowing tokens. + return_special_tokens_mask (`bool`, *optional*, defaults to `False`): + Whether or not to return special tokens mask information. + return_offsets_mapping (`bool`, *optional*, defaults to `False`): + Whether or not to return `(char_start, char_end)` for each token. + + This is only available on fast tokenizers inheriting from [`PreTrainedTokenizerFast`], if using + Python's tokenizer, this method will raise `NotImplementedError`. + return_length (`bool`, *optional*, defaults to `False`): + Whether or not to return the lengths of the encoded inputs. + verbose (`bool`, *optional*, defaults to `True`): + Whether or not to print more information and warnings. + **kwargs: passed to the `self.tokenize()` method + + Return: + [`BatchEncoding`]: A [`BatchEncoding`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. + + [What are input IDs?](../glossary#input-ids) + + - **token_type_ids** -- List of token type ids to be fed to a model (when `return_token_type_ids=True` or + if *"token_type_ids"* is in `self.model_input_names`). + + [What are token type IDs?](../glossary#token-type-ids) + + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names`). + + [What are attention masks?](../glossary#attention-mask) + + - **overflowing_tokens** -- List of overflowing tokens sequences (when a `max_length` is specified and + `return_overflowing_tokens=True`). + - **num_truncated_tokens** -- Number of tokens truncated (when a `max_length` is specified and + `return_overflowing_tokens=True`). + - **special_tokens_mask** -- List of 0s and 1s, with 1 specifying added special tokens and 0 specifying + regular sequence tokens (when `add_special_tokens=True` and `return_special_tokens_mask=True`). + - **length** -- The length of the inputs (when `return_length=True`) +""" + + +INIT_TOKENIZER_DOCSTRING = r""" + Class attributes (overridden by derived classes) + + - **vocab_files_names** (`dict[str, str]`) -- A dictionary with, as keys, the `__init__` keyword name of each + vocabulary file required by the model, and as associated values, the filename for saving the associated file + (string). + - **pretrained_vocab_files_map** (`dict[str, dict[str, str]]`) -- A dictionary of dictionaries, with the + high-level keys being the `__init__` keyword name of each vocabulary file required by the model, the + low-level being the `short-cut-names` of the pretrained models with, as associated values, the `url` to the + associated pretrained vocabulary file. + - **model_input_names** (`list[str]`) -- A list of inputs expected in the forward pass of the model. + - **padding_side** (`str`) -- The default value for the side on which the model should have padding applied. + Should be `'right'` or `'left'`. + - **truncation_side** (`str`) -- The default value for the side on which the model should have truncation + applied. Should be `'right'` or `'left'`. + + Args: + model_max_length (`int`, *optional*): + The maximum length (in number of tokens) for the inputs to the transformer model. When the tokenizer is + loaded with [`~tokenization_utils_base.PreTrainedTokenizerBase.from_pretrained`], this will be set to the + value stored for the associated model in `max_model_input_sizes` (see above). If no value is provided, will + default to VERY_LARGE_INTEGER (`int(1e30)`). + padding_side (`str`, *optional*): + The side on which the model should have padding applied. Should be selected between ['right', 'left']. + Default value is picked from the class attribute of the same name. + truncation_side (`str`, *optional*): + The side on which the model should have truncation applied. Should be selected between ['right', 'left']. + Default value is picked from the class attribute of the same name. + chat_template (`str`, *optional*): + A Jinja template string that will be used to format lists of chat messages. See + https://huggingface.co/docs/transformers/chat_templating for a full description. + model_input_names (`list[string]`, *optional*): + The list of inputs accepted by the forward pass of the model (like `"token_type_ids"` or + `"attention_mask"`). Default value is picked from the class attribute of the same name. + bos_token (`str` or `tokenizers.AddedToken`, *optional*): + A special token representing the beginning of a sentence. Will be associated to `self.bos_token` and + `self.bos_token_id`. + eos_token (`str` or `tokenizers.AddedToken`, *optional*): + A special token representing the end of a sentence. Will be associated to `self.eos_token` and + `self.eos_token_id`. + unk_token (`str` or `tokenizers.AddedToken`, *optional*): + A special token representing an out-of-vocabulary token. Will be associated to `self.unk_token` and + `self.unk_token_id`. + sep_token (`str` or `tokenizers.AddedToken`, *optional*): + A special token separating two different sentences in the same input (used by BERT for instance). Will be + associated to `self.sep_token` and `self.sep_token_id`. + pad_token (`str` or `tokenizers.AddedToken`, *optional*): + A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by + attention mechanisms or loss computation. Will be associated to `self.pad_token` and `self.pad_token_id`. + cls_token (`str` or `tokenizers.AddedToken`, *optional*): + A special token representing the class of the input (used by BERT for instance). Will be associated to + `self.cls_token` and `self.cls_token_id`. + mask_token (`str` or `tokenizers.AddedToken`, *optional*): + A special token representing a masked token (used by masked-language modeling pretraining objectives, like + BERT). Will be associated to `self.mask_token` and `self.mask_token_id`. + additional_special_tokens (tuple or list of `str` or `tokenizers.AddedToken`, *optional*): + A tuple or a list of additional special tokens. Add them here to ensure they are skipped when decoding with + `skip_special_tokens` is set to True. If they are not part of the vocabulary, they will be added at the end + of the vocabulary. + clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`): + Whether or not the model should cleanup the spaces that were added when splitting the input text during the + tokenization process. + split_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the special tokens should be split during the tokenization process. Passing will affect the + internal state of the tokenizer. The default behavior is to not split special tokens. This means that if + `` is the `bos_token`, then `tokenizer.tokenize("") = ['`]. Otherwise, if + `split_special_tokens=True`, then `tokenizer.tokenize("")` will be give `['<','s', '>']`. +""" + + +@add_end_docstrings(INIT_TOKENIZER_DOCSTRING) +class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): + """ + Base class for [`PreTrainedTokenizer`] and [`PreTrainedTokenizerFast`]. + + Handles shared (mostly boiler plate) methods for those two classes. + """ + + vocab_files_names: dict[str, str] = {} + pretrained_vocab_files_map: dict[str, dict[str, str]] = {} + _auto_class: Optional[str] = None + + # first name has to correspond to main model input name + # to make sure `tokenizer.pad(...)` works correctly + model_input_names: list[str] = ["input_ids", "token_type_ids", "attention_mask"] + padding_side: str = "right" + truncation_side: str = "right" + slow_tokenizer_class = None + + def __init__(self, **kwargs): + # inputs and kwargs for saving and re-loading (see ``from_pretrained`` and ``save_pretrained``) + self.init_inputs = () + for key in kwargs: + if hasattr(self, key) and callable(getattr(self, key)): + raise AttributeError(f"{key} conflicts with the method {key} in {self.__class__.__name__}") + + self.init_kwargs = copy.deepcopy(kwargs) + self.name_or_path = kwargs.pop("name_or_path", "") + self._processor_class = kwargs.pop("processor_class", None) + + # For backward compatibility we fallback to set model_max_length from max_len if provided + model_max_length = kwargs.pop("model_max_length", kwargs.pop("max_len", None)) + self.model_max_length = model_max_length if model_max_length is not None else VERY_LARGE_INTEGER + + # Padding and truncation side are right by default and overridden in subclasses. If specified in the kwargs, it + # is changed. + self.padding_side = kwargs.pop("padding_side", self.padding_side) + if self.padding_side not in ["right", "left"]: + raise ValueError( + f"Padding side should be selected between 'right' and 'left', current value: {self.padding_side}" + ) + + self.truncation_side = kwargs.pop("truncation_side", self.truncation_side) + if self.truncation_side not in ["right", "left"]: + raise ValueError( + f"Truncation side should be selected between 'right' and 'left', current value: {self.truncation_side}" + ) + + self.model_input_names = kwargs.pop("model_input_names", self.model_input_names) + + # By default, cleaning tokenization spaces for both fast and slow tokenizers + self.clean_up_tokenization_spaces = kwargs.pop("clean_up_tokenization_spaces", False) + + # By default, do not split special tokens for both fast and slow tokenizers + self.split_special_tokens = kwargs.pop("split_special_tokens", False) + + self.deprecation_warnings = {} # Use to store when we have already noticed a deprecation warning (avoid overlogging). + self._in_target_context_manager = False + + # Stores a Jinja template that formats chat histories into tokenizable strings + self.chat_template = kwargs.pop("chat_template", None) + if isinstance(self.chat_template, (list, tuple)): + # Chat templates are stored as lists of dicts with fixed key names, + # we reconstruct that into a single dict while loading them. + self.chat_template = {template["name"]: template["template"] for template in self.chat_template} + + super().__init__(**kwargs) + + self.extra_special_tokens = kwargs.pop("extra_special_tokens", {}) + self._set_model_specific_special_tokens(special_tokens=self.extra_special_tokens) + + @property + def max_len_single_sentence(self) -> int: + """ + `int`: The maximum length of a sentence that can be fed to the model. + """ + return self.model_max_length - self.num_special_tokens_to_add(pair=False) + + @property + def max_len_sentences_pair(self) -> int: + """ + `int`: The maximum combined length of a pair of sentences that can be fed to the model. + """ + return self.model_max_length - self.num_special_tokens_to_add(pair=True) + + @max_len_single_sentence.setter + def max_len_single_sentence(self, value) -> int: + # For backward compatibility, allow to try to setup 'max_len_single_sentence'. + if value == self.model_max_length - self.num_special_tokens_to_add(pair=False) and self.verbose: + if not self.deprecation_warnings.get("max_len_single_sentence", False): + logger.warning( + "Setting 'max_len_single_sentence' is now deprecated. This value is automatically set up." + ) + self.deprecation_warnings["max_len_single_sentence"] = True + else: + raise ValueError( + "Setting 'max_len_single_sentence' is now deprecated. This value is automatically set up." + ) + + @max_len_sentences_pair.setter + def max_len_sentences_pair(self, value) -> int: + # For backward compatibility, allow to try to setup 'max_len_sentences_pair'. + if value == self.model_max_length - self.num_special_tokens_to_add(pair=True) and self.verbose: + if not self.deprecation_warnings.get("max_len_sentences_pair", False): + logger.warning( + "Setting 'max_len_sentences_pair' is now deprecated. This value is automatically set up." + ) + self.deprecation_warnings["max_len_sentences_pair"] = True + else: + raise ValueError("Setting 'max_len_sentences_pair' is now deprecated. This value is automatically set up.") + + def _set_processor_class(self, processor_class: str): + """Sets processor class as an attribute.""" + self._processor_class = processor_class + + @property + def added_tokens_decoder(self) -> dict[int, AddedToken]: + raise NotImplementedError() + + def __repr__(self) -> str: + added_tokens_decoder_rep = "\n\t".join([f"{k}: {v.__repr__()}," for k, v in self.added_tokens_decoder.items()]) + return ( + f"{self.__class__.__name__}(name_or_path='{self.name_or_path}'," + f" vocab_size={self.vocab_size}, model_max_length={self.model_max_length}, is_fast={self.is_fast}," + f" padding_side='{self.padding_side}', truncation_side='{self.truncation_side}'," + f" special_tokens={self.special_tokens_map}, clean_up_tokenization_spaces={self.clean_up_tokenization_spaces}," + " added_tokens_decoder={\n\t" + added_tokens_decoder_rep + "\n}\n)" + ) + + def __len__(self) -> int: + raise NotImplementedError() + + def get_vocab(self) -> dict[str, int]: + """ + Returns the vocabulary as a dictionary of token to index. + + `tokenizer.get_vocab()[token]` is equivalent to `tokenizer.convert_tokens_to_ids(token)` when `token` is in the + vocab. + + Returns: + `dict[str, int]`: The vocabulary. + """ + raise NotImplementedError() + + def apply_chat_template( + self, + conversation: Union[list[dict[str, str]], list[list[dict[str, str]]]], + tools: Optional[list[Union[dict, Callable]]] = None, + documents: Optional[list[dict[str, str]]] = None, + chat_template: Optional[str] = None, + add_generation_prompt: bool = False, + continue_final_message: bool = False, + tokenize: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: bool = False, + max_length: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_dict: bool = False, + return_assistant_tokens_mask: bool = False, + tokenizer_kwargs: Optional[dict[str, Any]] = None, + **kwargs, + ) -> Union[str, list[int], list[str], list[list[int]], BatchEncoding]: + """ + Converts a list of dictionaries with `"role"` and `"content"` keys to a list of token + ids. This method is intended for use with chat models, and will read the tokenizer's chat_template attribute to + determine the format and control tokens to use when converting. + + Args: + conversation (Union[list[dict[str, str]], list[list[dict[str, str]]]]): A list of dicts + with "role" and "content" keys, representing the chat history so far. + tools (`list[Union[Dict, Callable]]`, *optional*): + A list of tools (callable functions) that will be accessible to the model. If the template does not + support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema, + giving the name, description and argument types for the tool. See our + [tool use guide](https://huggingface.co/docs/transformers/en/chat_extras#passing-tools) + for more information. + documents (`list[dict[str, str]]`, *optional*): + A list of dicts representing documents that will be accessible to the model if it is performing RAG + (retrieval-augmented generation). If the template does not support RAG, this argument will have no + effect. We recommend that each document should be a dict containing "title" and "text" keys. + chat_template (`str`, *optional*): + A Jinja template to use for this conversion. It is usually not necessary to pass anything to this + argument, as the model's template will be used by default. + add_generation_prompt (bool, *optional*): + If this is set, a prompt with the token(s) that indicate + the start of an assistant message will be appended to the formatted output. This is useful when you want to generate a response from the model. + Note that this argument will be passed to the chat template, and so it must be supported in the + template for this argument to have any effect. + continue_final_message (bool, *optional*): + If this is set, the chat will be formatted so that the final + message in the chat is open-ended, without any EOS tokens. The model will continue this message + rather than starting a new one. This allows you to "prefill" part of + the model's response for it. Cannot be used at the same time as `add_generation_prompt`. + tokenize (`bool`, defaults to `True`): + Whether to tokenize the output. If `False`, the output will be a string. + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding + index) among: + + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + truncation (`bool`, defaults to `False`): + Whether to truncate sequences at the maximum length. Has no effect if tokenize is `False`. + max_length (`int`, *optional*): + Maximum length (in tokens) to use for padding or truncation. Has no effect if tokenize is `False`. If + not specified, the tokenizer's `max_length` attribute will be used as a default. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Has no effect if tokenize is `False`. Acceptable + values are: + - `'tf'`: Return TensorFlow `tf.Tensor` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + return_dict (`bool`, defaults to `False`): + Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`. + tokenizer_kwargs (`dict[str: Any]`, *optional*): Additional kwargs to pass to the tokenizer. + return_assistant_tokens_mask (`bool`, defaults to `False`): + Whether to return a mask of the assistant generated tokens. For tokens generated by the assistant, + the mask will contain 1. For user and system tokens, the mask will contain 0. + This functionality is only available for chat templates that support it via the `{% generation %}` keyword. + **kwargs: Additional kwargs to pass to the template renderer. Will be accessible by the chat template. + + Returns: + `Union[list[int], Dict]`: A list of token ids representing the tokenized chat so far, including control tokens. This + output is ready to pass to the model, either directly or via methods like `generate()`. If `return_dict` is + set, will return a dict of tokenizer outputs instead. + """ + + if return_dict and not tokenize: + raise ValueError( + "`return_dict=True` is incompatible with `tokenize=False`, because there is no dict " + "of tokenizer outputs to return." + ) + + if return_assistant_tokens_mask and not return_dict: + raise ValueError("`return_assistant_tokens_mask=True` is incompatible with `return_dict=False`") + + if tokenizer_kwargs is None: + tokenizer_kwargs = {} + + chat_template = self.get_chat_template(chat_template, tools) + + if isinstance(conversation, (list, tuple)) and ( + isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "messages") + ): + conversations = conversation + is_batched = True + else: + conversations = [conversation] + is_batched = False + + if continue_final_message: + if add_generation_prompt: + raise ValueError( + "continue_final_message and add_generation_prompt are not compatible. Use continue_final_message when you want the model to continue the final message, and add_generation_prompt when you want to add a header that will prompt it to start a new assistant message instead." + ) + if return_assistant_tokens_mask: + raise ValueError("continue_final_message is not compatible with return_assistant_tokens_mask.") + + template_kwargs = {**self.special_tokens_map, **kwargs} # kwargs overwrite special tokens if both are present + rendered_chat, generation_indices = render_jinja_template( + conversations=conversations, + tools=tools, + documents=documents, + chat_template=chat_template, + return_assistant_tokens_mask=return_assistant_tokens_mask, + continue_final_message=continue_final_message, + add_generation_prompt=add_generation_prompt, + **template_kwargs, + ) + + if not is_batched: + rendered_chat = rendered_chat[0] + + if tokenize: + out = self( + rendered_chat, + padding=padding, + truncation=truncation, + max_length=max_length, + add_special_tokens=False, + return_tensors=return_tensors, + **tokenizer_kwargs, + ) + if return_dict: + if return_assistant_tokens_mask: + assistant_masks = [] + if is_batched or return_tensors: + input_ids = out["input_ids"] + else: + input_ids = [out["input_ids"]] + for i in range(len(input_ids)): + current_mask = [0] * len(input_ids[i]) + for assistant_start_char, assistant_end_char in generation_indices[i]: + start_token = out.char_to_token(i, assistant_start_char) + end_token = out.char_to_token(i, assistant_end_char - 1) + if start_token is None: + # start_token is out of bounds maybe due to truncation. + break + for token_id in range(start_token, end_token + 1 if end_token else len(input_ids[i])): + current_mask[token_id] = 1 + assistant_masks.append(current_mask) + + if not is_batched and not return_tensors: + assistant_masks = assistant_masks[0] + + out["assistant_masks"] = assistant_masks + + if return_tensors: + out.convert_to_tensors(tensor_type=return_tensors) + + return out + else: + return out["input_ids"] + else: + return rendered_chat + + def encode_message_with_chat_template( + self, + message: dict[str, str], + conversation_history: Optional[list[dict[str, str]]] = None, + **kwargs, + ) -> list[int]: + """ + Tokenize a single message. This method is a convenience wrapper around `apply_chat_template` that allows you + to tokenize messages one by one. This is useful for things like token-by-token streaming. + This method is not guaranteed to be perfect. For some models, it may be impossible to robustly tokenize + single messages. For example, if the chat template adds tokens after each message, but also has a prefix that + is added to the entire chat, it will be impossible to distinguish a chat-start-token from a message-start-token. + In these cases, this method will do its best to find the correct tokenization, but it may not be perfect. + **Note:** This method does not support `add_generation_prompt`. If you want to add a generation prompt, + you should do it separately after tokenizing the conversation. + Args: + message (`dict`): + A dictionary with "role" and "content" keys, representing the message to tokenize. + conversation_history (`list[dict]`, *optional*): + A list of dicts with "role" and "content" keys, representing the chat history so far. If you are + tokenizing messages one by one, you should pass the previous messages in the conversation here. + **kwargs: + Additional kwargs to pass to the `apply_chat_template` method. + Returns: + `list[int]`: A list of token ids representing the tokenized message. + """ + if "add_generation_prompt" in kwargs: + raise ValueError( + "`encode_message_with_chat_template` does not support `add_generation_prompt`. Please add the generation prompt " + "separately." + ) + + if conversation_history is None or len(conversation_history) == 0: + return self.apply_chat_template([message], add_generation_prompt=False, tokenize=True, **kwargs) + + conversation = conversation_history + [message] + tokens = self.apply_chat_template(conversation, add_generation_prompt=False, tokenize=True, **kwargs) + + prefix_tokens = self.apply_chat_template( + conversation_history, add_generation_prompt=False, tokenize=True, **kwargs + ) + # It's possible that the prefix tokens are not a prefix of the full list of tokens. + # For example, if the prefix is `User: Hi` and the full conversation is `User: HiAssistant: Hello`. + # In this case, we can't simply find the prefix, so we have to do something a bit more subtle. + # We look for the first place where the tokens differ, and that's our split point. + # This is not perfect, but it's the best we can do without a token-level API. + # To make this more robust, we could do a diff and find the longest common subsequence, but this is + # a good first approximation. + # This is particularly important for models like Llama3 that have changed their chat template to include + # EOS tokens after user messages. + min_len = min(len(prefix_tokens), len(tokens)) + for i in range(min_len): + if prefix_tokens[i] != tokens[i]: + return tokens[i:] + return tokens[min_len:] + + def get_chat_template(self, chat_template: Optional[str] = None, tools: Optional[list[dict]] = None) -> str: + """ + Retrieve the chat template string used for tokenizing chat messages. This template is used + internally by the `apply_chat_template` method and can also be used externally to retrieve the model's chat + template for better generation tracking. + + Args: + chat_template (`str`, *optional*): + A Jinja template or the name of a template to use for this conversion. + It is usually not necessary to pass anything to this argument, + as the model's template will be used by default. + tools (`list[Dict]`, *optional*): + A list of tools (callable functions) that will be accessible to the model. If the template does not + support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema, + giving the name, description and argument types for the tool. See our + [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use) + for more information. + + Returns: + `str`: The chat template string. + """ + # First, handle the cases when the model has a dict of multiple templates + if isinstance(self.chat_template, dict): + template_dict = self.chat_template + if chat_template is not None and chat_template in template_dict: + # The user can pass the name of a template to the chat template argument instead of an entire template + chat_template = template_dict[chat_template] + elif chat_template is None: + if tools is not None and "tool_use" in template_dict: + chat_template = template_dict["tool_use"] + elif "default" in template_dict: + chat_template = template_dict["default"] + else: + raise ValueError( + "This model has multiple chat templates with no default specified! Please either pass a chat " + "template or the name of the template you wish to use to the `chat_template` argument. Available " + f"template names are {sorted(template_dict.keys())}." + ) + + elif chat_template is None: + # These are the cases when the model has a single template + # priority: `chat_template` argument > `tokenizer.chat_template` + if self.chat_template is not None: + chat_template = self.chat_template + else: + raise ValueError( + "Cannot use chat template functions because tokenizer.chat_template is not set and no template " + "argument was passed! For information about writing templates and setting the " + "tokenizer.chat_template attribute, please see the documentation at " + "https://huggingface.co/docs/transformers/main/en/chat_templating" + ) + + return chat_template + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Union[str, os.PathLike], + *init_inputs, + cache_dir: Optional[Union[str, os.PathLike]] = None, + force_download: bool = False, + local_files_only: bool = False, + token: Optional[Union[str, bool]] = None, + revision: str = "main", + trust_remote_code=False, + **kwargs, + ): + r""" + Instantiate a [`~tokenization_utils_base.PreTrainedTokenizerBase`] (or a derived class) from a predefined + tokenizer. + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`): + Can be either: + + - A string, the *model id* of a predefined tokenizer hosted inside a model repo on huggingface.co. + - A path to a *directory* containing vocabulary files required by the tokenizer, for instance saved + using the [`~tokenization_utils_base.PreTrainedTokenizerBase.save_pretrained`] method, e.g., + `./my_model_directory/`. + - (**Deprecated**, not applicable to all derived classes) A path or url to a single saved vocabulary + file (if and only if the tokenizer only requires a single vocabulary file like Bert or XLNet), e.g., + `./my_model_directory/vocab.txt`. + cache_dir (`str` or `os.PathLike`, *optional*): + Path to a directory in which a downloaded predefined tokenizer vocabulary files should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download the vocabulary files and override the cached versions if they + exist. + resume_download: + Deprecated and ignored. All downloads are now resumed by default when possible. + Will be removed in v5 of Transformers. + proxies (`dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `hf auth login` (stored in `~/.huggingface`). + local_files_only (`bool`, *optional*, defaults to `False`): + Whether or not to only rely on local files and not to attempt to download any files. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + subfolder (`str`, *optional*): + In case the relevant files are located inside a subfolder of the model repo on huggingface.co (e.g. for + facebook/rag-token-base), specify it here. + inputs (additional positional arguments, *optional*): + Will be passed along to the Tokenizer `__init__` method. + trust_remote_code (`bool`, *optional*, defaults to `False`): + Whether or not to allow for custom models defined on the Hub in their own modeling files. This option + should only be set to `True` for repositories you trust and in which you have read the code, as it will + execute code present on the Hub on your local machine. + kwargs (additional keyword arguments, *optional*): + Will be passed to the Tokenizer `__init__` method. Can be used to set special tokens like `bos_token`, + `eos_token`, `unk_token`, `sep_token`, `pad_token`, `cls_token`, `mask_token`, + `additional_special_tokens`. See parameters in the `__init__` for more details. + + + + Passing `token=True` is required when you want to use a private model. + + + + Examples: + + ```python + # We can't instantiate directly the base class *PreTrainedTokenizerBase* so let's show our examples on a derived class: BertTokenizer + # Download vocabulary from huggingface.co and cache. + tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased") + + # Download vocabulary from huggingface.co (user-uploaded) and cache. + tokenizer = BertTokenizer.from_pretrained("dbmdz/bert-base-german-cased") + + # If vocabulary files are in a directory (e.g. tokenizer was saved using *save_pretrained('./test/saved_model/')*) + tokenizer = BertTokenizer.from_pretrained("./test/saved_model/") + + # If the tokenizer uses a single vocabulary file, you can point directly to this file + tokenizer = BertTokenizer.from_pretrained("./test/saved_model/my_vocab.txt") + + # You can link tokens to special vocabulary when instantiating + tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased", unk_token="") + # You should be sure '' is in the vocabulary when doing that. + # Otherwise use tokenizer.add_special_tokens({'unk_token': ''}) instead) + assert tokenizer.unk_token == "" + ```""" + resume_download = kwargs.pop("resume_download", None) + proxies = kwargs.pop("proxies", None) + use_auth_token = kwargs.pop("use_auth_token", None) + subfolder = kwargs.pop("subfolder", None) + from_pipeline = kwargs.pop("_from_pipeline", None) + from_auto_class = kwargs.pop("_from_auto", False) + commit_hash = kwargs.pop("_commit_hash", None) + gguf_file = kwargs.get("gguf_file") + + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if token is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + token = use_auth_token + + user_agent = {"file_type": "tokenizer", "from_auto_class": from_auto_class, "is_fast": "Fast" in cls.__name__} + if from_pipeline is not None: + user_agent["using_pipeline"] = from_pipeline + + if is_offline_mode() and not local_files_only: + logger.info("Offline mode: forcing local_files_only=True") + local_files_only = True + + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + vocab_files = {} + init_configuration = {} + + is_local = os.path.isdir(pretrained_model_name_or_path) + single_file_id = None + if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): + if len(cls.vocab_files_names) > 1 and not gguf_file: + raise ValueError( + f"Calling {cls.__name__}.from_pretrained() with the path to a single file or url is not " + "supported for this tokenizer. Use a model identifier or the path to a directory instead." + ) + warnings.warn( + f"Calling {cls.__name__}.from_pretrained() with the path to a single file or url is deprecated and " + "won't be possible anymore in v5. Use a model identifier or the path to a directory instead.", + FutureWarning, + ) + file_id = list(cls.vocab_files_names.keys())[0] + + vocab_files[file_id] = pretrained_model_name_or_path + single_file_id = file_id + else: + if gguf_file: + vocab_files["vocab_file"] = gguf_file + else: + # At this point pretrained_model_name_or_path is either a directory or a model identifier name + additional_files_names = { + "added_tokens_file": ADDED_TOKENS_FILE, # kept only for legacy + "special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE, # kept only for legacy + "tokenizer_config_file": TOKENIZER_CONFIG_FILE, + # tokenizer_file used to initialize a slow from a fast. Properly copy the `addedTokens` instead of adding in random orders + "tokenizer_file": FULL_TOKENIZER_FILE, + "chat_template_file": CHAT_TEMPLATE_FILE, + } + + vocab_files = {**cls.vocab_files_names, **additional_files_names} + if "tokenizer_file" in vocab_files: + # Try to get the tokenizer config to see if there are versioned tokenizer files. + fast_tokenizer_file = FULL_TOKENIZER_FILE + + try: + resolved_config_file = cached_file( + pretrained_model_name_or_path, + TOKENIZER_CONFIG_FILE, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + token=token, + revision=revision, + local_files_only=local_files_only, + subfolder=subfolder, + user_agent=user_agent, + _raise_exceptions_for_missing_entries=False, + _commit_hash=commit_hash, + ) + except OSError: + # Re-raise any error raised by cached_file in order to get a helpful error message + raise + except Exception: + # For any other exception, we throw a generic error. + raise OSError( + f"Can't load tokenizer for '{pretrained_model_name_or_path}'. If you were trying to load it from " + "'https://huggingface.co/models', make sure you don't have a local directory with the same name. " + f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " + f"containing all relevant files for a {cls.__name__} tokenizer." + ) + + commit_hash = extract_commit_hash(resolved_config_file, commit_hash) + if resolved_config_file is not None: + with open(resolved_config_file, encoding="utf-8") as reader: + tokenizer_config = json.load(reader) + if "fast_tokenizer_files" in tokenizer_config: + fast_tokenizer_file = get_fast_tokenizer_file(tokenizer_config["fast_tokenizer_files"]) + vocab_files["tokenizer_file"] = fast_tokenizer_file + + # This block looks for any extra chat template files + if is_local: + template_dir = Path(pretrained_model_name_or_path, CHAT_TEMPLATE_DIR) + if template_dir.is_dir(): + for template_file in template_dir.glob("*.jinja"): + template_name = template_file.name.removesuffix(".jinja") + vocab_files[f"chat_template_{template_name}"] = ( + f"{CHAT_TEMPLATE_DIR}/{template_file.name}" + ) + else: + for template in list_repo_templates( + pretrained_model_name_or_path, + local_files_only=local_files_only, + revision=revision, + cache_dir=cache_dir, + token=token, + ): + template = template.removesuffix(".jinja") + vocab_files[f"chat_template_{template}"] = f"{CHAT_TEMPLATE_DIR}/{template}.jinja" + + remote_files = [] + if not is_local and not local_files_only: + try: + remote_files = list_repo_files(pretrained_model_name_or_path) + except Exception: + remote_files = [] + elif pretrained_model_name_or_path and os.path.isdir(pretrained_model_name_or_path): + remote_files = os.listdir(pretrained_model_name_or_path) + + if "tokenizer_file" in vocab_files and not re.search(vocab_files["tokenizer_file"], "".join(remote_files)): + # mistral tokenizer names are different, but we can still convert them if + # mistral common is not there + other_pattern = r"tekken\.json|tokenizer\.model\.*" + if match := re.search(other_pattern, "\n".join(remote_files)): + vocab_files["vocab_file"] = match.group() + + resolved_vocab_files = {} + for file_id, file_path in vocab_files.items(): + if file_path is None: + resolved_vocab_files[file_id] = None + elif single_file_id == file_id: + if os.path.isfile(file_path): + resolved_vocab_files[file_id] = file_path + elif is_remote_url(file_path): + resolved_vocab_files[file_id] = download_url(file_path, proxies=proxies) + else: + try: + resolved_vocab_files[file_id] = cached_file( + pretrained_model_name_or_path, + file_path, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + revision=revision, + subfolder=subfolder, + _raise_exceptions_for_missing_entries=False, + _commit_hash=commit_hash, + ) + except OSError: + # Re-raise any error raised by cached_file in order to get a helpful error message + raise + except Exception: + # For any other exception, we throw a generic error. + raise OSError( + f"Can't load tokenizer for '{pretrained_model_name_or_path}'. If you were trying to load it from " + "'https://huggingface.co/models', make sure you don't have a local directory with the same name. " + f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " + f"containing all relevant files for a {cls.__name__} tokenizer." + ) + commit_hash = extract_commit_hash(resolved_vocab_files[file_id], commit_hash) + + for file_id, file_path in vocab_files.items(): + if file_id not in resolved_vocab_files: + continue + + if is_local: + logger.info(f"loading file {file_path}") + else: + logger.info(f"loading file {file_path} from cache at {resolved_vocab_files[file_id]}") + + return cls._from_pretrained( + resolved_vocab_files, + pretrained_model_name_or_path, + init_configuration, + *init_inputs, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + _commit_hash=commit_hash, + _is_local=is_local, + trust_remote_code=trust_remote_code, + **kwargs, + ) + + @classmethod + def _from_pretrained( + cls, + resolved_vocab_files, + pretrained_model_name_or_path, + init_configuration, + *init_inputs, + token=None, + cache_dir=None, + local_files_only=False, + _commit_hash=None, + _is_local=False, + trust_remote_code=False, + **kwargs, + ): + # We instantiate fast tokenizers based on a slow tokenizer if we don't have access to the tokenizer.json + # file or if `from_slow` is set to True. + from_slow = kwargs.get("from_slow", False) + gguf_file = kwargs.get("gguf_file") + has_tokenizer_file = resolved_vocab_files.get("tokenizer_file", None) is not None + + # If one passes a GGUF file path to `gguf_file` there is no need for this check as the tokenizer will be + # loaded directly from the GGUF file. + if (from_slow or not has_tokenizer_file) and cls.slow_tokenizer_class is not None and not gguf_file: + slow_tokenizer = (cls.slow_tokenizer_class)._from_pretrained( + copy.deepcopy(resolved_vocab_files), + pretrained_model_name_or_path, + copy.deepcopy(init_configuration), + *init_inputs, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + _commit_hash=_commit_hash, + **(copy.deepcopy(kwargs)), + ) + else: + slow_tokenizer = None + + # Prepare tokenizer initialization kwargs + # Did we saved some inputs and kwargs to reload ? + tokenizer_config_file = resolved_vocab_files.pop("tokenizer_config_file", None) + if tokenizer_config_file is not None: + with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle: + init_kwargs = json.load(tokenizer_config_handle) + # First attempt. We get tokenizer_class from tokenizer_config to check mismatch between tokenizers. + config_tokenizer_class = init_kwargs.get("tokenizer_class") + init_kwargs.pop("tokenizer_class", None) + if not has_tokenizer_file: + init_kwargs.pop("tokenizer_file", None) + saved_init_inputs = init_kwargs.pop("init_inputs", ()) + if not init_inputs: + init_inputs = saved_init_inputs + else: + config_tokenizer_class = None + init_kwargs = init_configuration + + # If independent chat template file(s) exist, they take priority over template entries in the tokenizer config + chat_templates = {} + chat_template_file = resolved_vocab_files.pop("chat_template_file", None) + extra_chat_templates = [key for key in resolved_vocab_files if key.startswith("chat_template_")] + if chat_template_file is not None: + with open(chat_template_file, encoding="utf-8") as chat_template_handle: + chat_templates["default"] = chat_template_handle.read() + for extra_chat_template in extra_chat_templates: + template_file = resolved_vocab_files.pop(extra_chat_template, None) + if template_file is None: + continue # I think this should never happen, but just in case + template_name = extra_chat_template.removeprefix("chat_template_") + with open(template_file, encoding="utf8") as chat_template_handle: + chat_templates[template_name] = chat_template_handle.read() + if len(chat_templates) == 1 and "default" in chat_templates: + init_kwargs["chat_template"] = chat_templates["default"] + elif chat_templates: + init_kwargs["chat_template"] = chat_templates + + if not _is_local: + if "auto_map" in init_kwargs: + # For backward compatibility with odl format. + if isinstance(init_kwargs["auto_map"], (tuple, list)): + init_kwargs["auto_map"] = {"AutoTokenizer": init_kwargs["auto_map"]} + + if config_tokenizer_class is None: + # Matt: This entire block is only used to decide if the tokenizer class matches the class in the repo. + # If not, it raises a warning, but otherwise continues. Since we mostly load tokenizers with + # AutoTokenizer these days, it seems like a lot of work (and a source of bugs) for little gain. + # Maybe we can just remove this entirely? + from .models.auto.configuration_auto import AutoConfig # tests_ignore + + # Second attempt. If we have not yet found tokenizer_class, let's try to use the config. + try: + config = AutoConfig.from_pretrained( + pretrained_model_name_or_path, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + trust_remote_code=trust_remote_code, + _commit_hash=_commit_hash, + ) + config_tokenizer_class = config.tokenizer_class + except (OSError, ValueError, KeyError): + # skip if an error occurred. + config = None + if config_tokenizer_class is None: + # Third attempt. If we have not yet found the original type of the tokenizer, + # we are loading we see if we can infer it from the type of the configuration file + from .models.auto.tokenization_auto import TOKENIZER_MAPPING_NAMES # tests_ignore + + if hasattr(config, "model_type"): + model_type = config.model_type + else: + # Fallback: use pattern matching on the string. + model_type = None + for pattern in TOKENIZER_MAPPING_NAMES: + if pattern in str(pretrained_model_name_or_path): + model_type = pattern + break + + if model_type is not None: + config_tokenizer_class, config_tokenizer_class_fast = TOKENIZER_MAPPING_NAMES.get( + model_type, (None, None) + ) + if config_tokenizer_class is None: + config_tokenizer_class = config_tokenizer_class_fast + + if config_tokenizer_class is not None: + if cls.__name__.replace("Fast", "") != config_tokenizer_class.replace("Fast", ""): + logger.warning( + "The tokenizer class you load from this checkpoint is not the same type as the class this" + " function is called from. It may result in unexpected tokenization. \nThe tokenizer class you" + f" load from this checkpoint is '{config_tokenizer_class}'. \nThe class this function is called" + f" from is '{cls.__name__}'." + ) + + # Update with newly provided kwargs + init_kwargs.update(kwargs) + + # Merge resolved_vocab_files arguments in init_kwargs. + added_tokens_file = resolved_vocab_files.pop("added_tokens_file", None) + special_tokens_map_file = resolved_vocab_files.pop("special_tokens_map_file", None) + for args_name, file_path in resolved_vocab_files.items(): + if args_name not in init_kwargs: + init_kwargs[args_name] = file_path + tokenizer_file = resolved_vocab_files.pop("tokenizer_file", None) + + if slow_tokenizer is not None: + init_kwargs["__slow_tokenizer"] = slow_tokenizer + init_kwargs["name_or_path"] = pretrained_model_name_or_path + + #### Handle tokenizer serialization of added and special tokens + added_tokens_decoder: dict[int, AddedToken] = {} + added_tokens_map: dict[str, AddedToken] = {} + # if we have info on the slow added tokens + if "added_tokens_decoder" in init_kwargs: + for idx, token in init_kwargs["added_tokens_decoder"].items(): + if isinstance(token, dict): + token = AddedToken(**token) + if isinstance(token, AddedToken): + added_tokens_decoder[int(idx)] = token + added_tokens_map[str(token)] = token + else: + raise TypeError( + f"Found a {token.__class__} in the saved `added_tokens_decoder`, should be a dictionary or an AddedToken instance" + ) + else: + # begin legacy: read the added_tokens_file and update kwargs with special_tokens_map if modified + if special_tokens_map_file is not None: + with open(special_tokens_map_file, encoding="utf-8") as special_tokens_map_handle: + special_tokens_map = json.load(special_tokens_map_handle) + for key, value in special_tokens_map.items(): + if key in kwargs and kwargs[key]: + # This value has already been redefined by the kwargs + # We keep this new value and ignore the one stored in the special_tokens_map_file + continue + if isinstance(value, dict): + value["special"] = True + value = AddedToken(**value) + elif key == "additional_special_tokens" and isinstance(value, list): + additional_special_tokens = init_kwargs.pop("additional_special_tokens", []) or [] + for token in value: + if isinstance(token, dict): + token["special"] = True + token = AddedToken(**token) + if token not in additional_special_tokens: + additional_special_tokens.append(token) + value = additional_special_tokens + init_kwargs[key] = value + + # slow -> slow|fast, legacy: convert the `"added_tokens.json"` file to `added_tokens_decoder`. + # this is for legacy purpose. We don't add the tokens after init for efficiency. + if added_tokens_file is not None: + special_tokens = [] + for key in cls.SPECIAL_TOKENS_ATTRIBUTES & init_kwargs.keys(): + if init_kwargs[key] is not None: + if key == "additional_special_tokens": + special_tokens += [str(token) for token in init_kwargs[key]] + else: + special_tokens.append(str(init_kwargs[key])) + + with open(added_tokens_file, encoding="utf-8") as added_tokens_handle: + added_tok_encoder = json.load(added_tokens_handle) + for str_token, index in added_tok_encoder.items(): + # if index not in added_tokens_decoder and str_token not in added_tokens_map: + special = str_token in special_tokens + added_tokens_decoder[index] = AddedToken( + str_token, rstrip=False, lstrip=False, normalized=not special, special=special + ) + added_tokens_map[str(token)] = added_tokens_decoder[index] + + # allows converting a fast -> slow: add the `tokenizer.json`'s `"added_tokens"` to the slow tokenizer + # if `tokenizer_config.json` is `None` + if tokenizer_file is not None: + # This is for slow so can be done before + with open(tokenizer_file, encoding="utf-8") as tokenizer_file_handle: + tokenizer_file_handle = json.load(tokenizer_file_handle) + added_tokens = tokenizer_file_handle.pop("added_tokens") + for serialized_tokens in added_tokens: + idx = serialized_tokens.pop("id") + added_tokens_decoder[idx] = AddedToken(**serialized_tokens) + added_tokens_map[str(added_tokens_decoder[idx])] = added_tokens_decoder[idx] + # end legacy + + # Passing AddedTokens and not strings to the class to prevent it from casting the string to a different AddedToken + # convert {'__type': 'AddedToken', 'content': '', 'lstrip': False, 'normalized': True, ...} to AddedTokens + init_kwargs["added_tokens_decoder"] = added_tokens_decoder + init_kwargs = cls.convert_added_tokens(init_kwargs, save=False) + for key in cls.SPECIAL_TOKENS_ATTRIBUTES & init_kwargs.keys(): + if added_tokens_map != {} and init_kwargs[key] is not None: + if key != "additional_special_tokens": + init_kwargs[key] = added_tokens_map.get(str(init_kwargs[key]), init_kwargs[key]) + + # Instantiate the tokenizer. + try: + tokenizer = cls(*init_inputs, **init_kwargs) + except import_protobuf_decode_error(): + logger.info( + "Unable to load tokenizer model from SPM, loading from TikToken will be attempted instead." + "(Google protobuf error: Tried to load SPM model with non-SPM vocab file).", + ) + return False + except RuntimeError as e: + if "sentencepiece_processor.cc" in str(e): + logger.info( + "Unable to load tokenizer model from SPM, loading from TikToken will be attempted instead." + "(SentencePiece RuntimeError: Tried to load SPM model with non-SPM vocab file).", + ) + return False + except OSError: + raise OSError( + "Unable to load vocabulary from file. " + "Please check that the provided vocabulary is accessible and not corrupted." + ) + + if added_tokens_decoder != {} and max(list(added_tokens_decoder.keys())[-1], 0) > tokenizer.vocab_size: + logger.info( + "Special tokens have been added in the vocabulary, make sure the associated word embeddings are" + " fine-tuned or trained." + ) + try: + vocab_size = tokenizer.vocab_size + except NotImplementedError: + vocab_size = 0 + + # Optionally patches mistral tokenizers with wrong regex + if ( + vocab_size > 100000 + and hasattr(tokenizer, "_tokenizer") + and getattr(tokenizer._tokenizer, "pre_tokenizer", None) is not None + ): + tokenizer = cls._patch_mistral_regex( + tokenizer, + pretrained_model_name_or_path, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + _commit_hash=_commit_hash, + _is_local=_is_local, + init_kwargs=init_kwargs, + fix_mistral_regex=kwargs.get("fix_mistral_regex"), + ) + + return tokenizer + + @classmethod + def _patch_mistral_regex( + cls, + tokenizer, + pretrained_model_name_or_path, + token=None, + cache_dir=None, + local_files_only=False, + _commit_hash=None, + _is_local=False, + init_kwargs=None, + fix_mistral_regex=None, + ): + """ + Patches mistral related tokenizers with incorrect regex if detected + 1) Local file with an associated config saved next to it + >> Model type one of the mistral models (on older versions) + 2) Remote models on the hub from official mistral models + >> Tags including `base_model:.*mistralai` + """ + from huggingface_hub import model_info + + def is_base_mistral(model_id: str) -> bool: + model = model_info(model_id) + if model.tags is not None: + if re.search("base_model:.*mistralai", "".join(model.tags)): + return True + return False + + if is_offline_mode(): + _is_local = True + + if pretrained_model_name_or_path is not None and ( + _is_local or (not _is_local and is_base_mistral(pretrained_model_name_or_path)) + ): + _config_file = cached_file( + pretrained_model_name_or_path, + "config.json", + cache_dir=cache_dir, + token=token, + local_files_only=local_files_only, + _raise_exceptions_for_missing_entries=False, + _raise_exceptions_for_connection_errors=False, + _commit_hash=_commit_hash, + ) + + # Detected using a (local) mistral tokenizer + mistral_config_detected = False + if _config_file is not None: + with open(_config_file, encoding="utf-8") as f: + _config = json.load(f) + transformers_version = _config.get("transformers_version") + transformers_model_type = _config.get("model_type") + + # Detect if we can skip the mistral fix by + # a) having a non-mistral tokenizer + # b) fixed version of transformers + if transformers_version and version.parse(transformers_version) <= version.parse("4.57.2"): + if ( + _is_local + and transformers_model_type is not None + and transformers_model_type + not in [ + "mistral", + "mistral3", + "voxtral", + "ministral", + "pixtral", + ] + ): + return tokenizer + elif transformers_version and version.parse(transformers_version) >= version.parse("5.0.0"): + return tokenizer + + mistral_config_detected = True + + if mistral_config_detected or (not _is_local and is_base_mistral(pretrained_model_name_or_path)): + # Expose the `fix_mistral_regex` flag on the tokenizer when provided, even if no correction is applied. + if init_kwargs and "fix_mistral_regex" in init_kwargs: + setattr(tokenizer, "fix_mistral_regex", init_kwargs["fix_mistral_regex"]) + + # only warn if its not explicitly passed + if fix_mistral_regex is None and not getattr(tokenizer, "fix_mistral_regex", False): + setattr(tokenizer, "fix_mistral_regex", False) + logger.warning( + f"The tokenizer you are loading from '{pretrained_model_name_or_path}'" + f" with an incorrect regex pattern: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/discussions/84#69121093e8b480e709447d5e." + " This will lead to incorrect tokenization. You should set the `fix_mistral_regex=True` flag when loading this tokenizer to fix this issue." + ) + elif fix_mistral_regex is True or getattr(tokenizer, "fix_mistral_regex", False): + setattr(tokenizer, "fix_mistral_regex", True) + import tokenizers + + tokenizer.backend_tokenizer.pre_tokenizer[0] = tokenizers.pre_tokenizers.Split( + pattern=tokenizers.Regex( + r"[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+" + ), + behavior="isolated", + ) + return tokenizer + + @staticmethod + def _eventually_correct_t5_max_length(pretrained_model_name_or_path, max_model_length, init_max_model_length): + # This method should be deleted in Transformers v5 + # Its only purpose is to potentially throw a warning + # that incorrectly defined max lengths of T5's tokenizer are used + # which we will correct in Transformers v5. + return max_model_length + + @classmethod + def convert_added_tokens(cls, obj: Union[AddedToken, Any], save=False, add_type_field=True): + if isinstance(obj, dict) and "__type" in obj and obj["__type"] == "AddedToken": + obj.pop("__type") + return AddedToken(**obj) + if isinstance(obj, AddedToken) and save: + obj = obj.__getstate__() + if add_type_field: + obj["__type"] = "AddedToken" + else: + # Don't save "special" for previous tokenizers + obj.pop("special") + return obj + elif isinstance(obj, (list, tuple)): + return [cls.convert_added_tokens(o, save=save, add_type_field=add_type_field) for o in obj] + elif isinstance(obj, dict): + return {k: cls.convert_added_tokens(v, save=save, add_type_field=add_type_field) for k, v in obj.items()} + return obj + + def save_chat_templates( + self, + save_directory: Union[str, os.PathLike], + tokenizer_config: dict, + filename_prefix: Optional[str], + save_jinja_files: bool, + ): + """ + Writes chat templates out to the save directory if we're using the new format, and removes them from + the tokenizer config if present. If we're using the legacy format, it doesn't write any files, and instead + writes the templates to the tokenizer config in the correct format. + """ + chat_template_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + CHAT_TEMPLATE_FILE + ) + chat_template_dir = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + CHAT_TEMPLATE_DIR + ) + + saved_raw_chat_template_files = [] + if save_jinja_files and isinstance(self.chat_template, str): + # New format for single templates is to save them as chat_template.jinja + with open(chat_template_file, "w", encoding="utf-8") as f: + f.write(self.chat_template) + logger.info(f"chat template saved in {chat_template_file}") + saved_raw_chat_template_files.append(chat_template_file) + if "chat_template" in tokenizer_config: + tokenizer_config.pop("chat_template") # To ensure it doesn't somehow end up in the config too + elif save_jinja_files and isinstance(self.chat_template, dict): + # New format for multiple templates is to save the default as chat_template.jinja + # and the other templates in the chat_templates/ directory + for template_name, template in self.chat_template.items(): + if template_name == "default": + with open(chat_template_file, "w", encoding="utf-8") as f: + f.write(self.chat_template["default"]) + logger.info(f"chat template saved in {chat_template_file}") + saved_raw_chat_template_files.append(chat_template_file) + else: + Path(chat_template_dir).mkdir(exist_ok=True) + template_filepath = os.path.join(chat_template_dir, f"{template_name}.jinja") + with open(template_filepath, "w", encoding="utf-8") as f: + f.write(template) + logger.info(f"chat template saved in {template_filepath}") + saved_raw_chat_template_files.append(template_filepath) + if "chat_template" in tokenizer_config: + tokenizer_config.pop("chat_template") # To ensure it doesn't somehow end up in the config too + elif isinstance(self.chat_template, dict): + # Legacy format for multiple templates: + # chat template dicts are saved to the config as lists of dicts with fixed key names. + tokenizer_config["chat_template"] = [{"name": k, "template": v} for k, v in self.chat_template.items()] + elif self.chat_template is not None: + # Legacy format for single templates: Just make them a key in tokenizer_config.json + tokenizer_config["chat_template"] = self.chat_template + return tokenizer_config, saved_raw_chat_template_files + + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + legacy_format: Optional[bool] = None, + filename_prefix: Optional[str] = None, + push_to_hub: bool = False, + **kwargs, + ) -> tuple[str, ...]: + """ + Save the full tokenizer state. + + + This method make sure the full tokenizer can then be re-loaded using the + [`~tokenization_utils_base.PreTrainedTokenizer.from_pretrained`] class method.. + + Warning,None This won't save modifications you may have applied to the tokenizer after the instantiation (for + instance, modifying `tokenizer.do_lower_case` after creation). + + Args: + save_directory (`str` or `os.PathLike`): The path to a directory where the tokenizer will be saved. + legacy_format (`bool`, *optional*): + Only applicable for a fast tokenizer. If unset (default), will save the tokenizer in the unified JSON + format as well as in legacy format if it exists, i.e. with tokenizer specific vocabulary and a separate + added_tokens files. + + If `False`, will only save the tokenizer in the unified JSON format. This format is incompatible with + "slow" tokenizers (not powered by the *tokenizers* library), so the tokenizer will not be able to be + loaded in the corresponding "slow" tokenizer. + + If `True`, will save the tokenizer in legacy format. If the "slow" tokenizer doesn't exits, a value + error is raised. + filename_prefix (`str`, *optional*): + A prefix to add to the names of the files saved by the tokenizer. + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the + repository you want to push to with `repo_id` (will default to the name of `save_directory` in your + namespace). + kwargs (`dict[str, Any]`, *optional*): + Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. + + Returns: + A tuple of `str`: The files saved. + """ + use_auth_token = kwargs.pop("use_auth_token", None) + + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if kwargs.get("token") is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + kwargs["token"] = use_auth_token + + if os.path.isfile(save_directory): + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") + return + + os.makedirs(save_directory, exist_ok=True) + + if push_to_hub: + commit_message = kwargs.pop("commit_message", None) + repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) + repo_id = self._create_repo(repo_id, **kwargs) + files_timestamps = self._get_files_timestamps(save_directory) + + special_tokens_map_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + SPECIAL_TOKENS_MAP_FILE + ) + tokenizer_config_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + TOKENIZER_CONFIG_FILE + ) + + tokenizer_config = copy.deepcopy(self.init_kwargs) + + # Let's save the init kwargs + target_keys = set(self.init_kwargs.keys()) + # Let's save the special tokens map (only the strings) + target_keys.update(["model_max_length", "clean_up_tokenization_spaces"]) + + for k in target_keys: + if hasattr(self, k): + tokenizer_config[k] = getattr(self, k) + + # Let's make sure we properly save the special tokens + tokenizer_config.update(self.special_tokens_map) + if "extra_special_tokens" not in tokenizer_config: + tokenizer_config["extra_special_tokens"] = self.extra_special_tokens + tokenizer_config.update(self.extra_special_tokens) + + save_jinja_files = kwargs.get("save_jinja_files", True) + tokenizer_config, saved_raw_chat_template_files = self.save_chat_templates( + save_directory, tokenizer_config, filename_prefix, save_jinja_files + ) + + if len(self.init_inputs) > 0: + tokenizer_config["init_inputs"] = copy.deepcopy(self.init_inputs) + for file_id in self.vocab_files_names: + tokenizer_config.pop(file_id, None) + + # no typefields, this way old fast and slow can load it + tokenizer_config = self.convert_added_tokens(tokenizer_config, add_type_field=True, save=True) + + # Process added tokens separately: allows previous versions to ignore it! + added_tokens = {} + for key, value in self.added_tokens_decoder.items(): + added_tokens[key] = value.__getstate__() + tokenizer_config["added_tokens_decoder"] = added_tokens + + # Add tokenizer class to the tokenizer config to be able to reload it with from_pretrained + tokenizer_class = self.__class__.__name__ + # Remove the Fast at the end if we can save the slow tokenizer + if tokenizer_class.endswith("Fast") and getattr(self, "can_save_slow_tokenizer", False): + tokenizer_class = tokenizer_class[:-4] + tokenizer_config["tokenizer_class"] = tokenizer_class + if getattr(self, "_auto_map", None) is not None: + tokenizer_config["auto_map"] = self._auto_map + if getattr(self, "_processor_class", None) is not None: + tokenizer_config["processor_class"] = self._processor_class + + # If we have a custom model, we copy the file defining it in the folder and set the attributes so it can be + # loaded from the Hub. + if self._auto_class is not None: + custom_object_save(self, save_directory, config=tokenizer_config) + + # remove private information + if "name_or_path" in tokenizer_config: + tokenizer_config.pop("name_or_path") + tokenizer_config.pop("special_tokens_map_file", None) + tokenizer_config.pop("tokenizer_file", None) + if "device_map" in tokenizer_config: + tokenizer_config.pop("device_map") + + with open(tokenizer_config_file, "w", encoding="utf-8") as f: + out_str = json.dumps(tokenizer_config, indent=2, sort_keys=True, ensure_ascii=False) + "\n" + f.write(out_str) + logger.info(f"tokenizer config file saved in {tokenizer_config_file}") + + # Sanitize AddedTokens in special_tokens_map + + # kept for forward compatibility, will be removed in transoformers 5. Typefields are not saved for FC, special should not be save either + write_dict = self.convert_added_tokens(self.special_tokens_map_extended, save=True, add_type_field=False) + with open(special_tokens_map_file, "w", encoding="utf-8") as f: + out_str = json.dumps(write_dict, indent=2, sort_keys=True, ensure_ascii=False) + "\n" + f.write(out_str) + logger.info(f"Special tokens file saved in {special_tokens_map_file}") + + file_names = (tokenizer_config_file, special_tokens_map_file, *saved_raw_chat_template_files) + + save_files = self._save_pretrained( + save_directory=save_directory, + file_names=file_names, + legacy_format=legacy_format, + filename_prefix=filename_prefix, + ) + + if push_to_hub: + self._upload_modified_files( + save_directory, + repo_id, + files_timestamps, + commit_message=commit_message, + token=kwargs.get("token"), + ) + + return save_files + + def _save_pretrained( + self, + save_directory: Union[str, os.PathLike], + file_names: tuple[str, ...], + legacy_format: Optional[bool] = None, + filename_prefix: Optional[str] = None, + ) -> tuple[str, ...]: + """ + Save a tokenizer using the slow-tokenizer/legacy format: vocabulary + added tokens. + + Fast tokenizers can also be saved in a unique JSON file containing {config + vocab + added-tokens} using the + specific [`~tokenization_utils_fast.PreTrainedTokenizerFast._save_pretrained`] + """ + if legacy_format is False: + raise ValueError( + "Only fast tokenizers (instances of PreTrainedTokenizerFast) can be saved in non legacy format." + ) + + save_directory = str(save_directory) + + added_tokens_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + ADDED_TOKENS_FILE + ) + # the new get_added_vocab() also returns special tokens and tokens that have an index < vocab_size + added_vocab = {tok: index for tok, index in self.added_tokens_encoder.items() if index >= self.vocab_size} + if added_vocab: + with open(added_tokens_file, "w", encoding="utf-8") as f: + out_str = json.dumps(added_vocab, indent=2, sort_keys=True, ensure_ascii=False) + "\n" + f.write(out_str) + logger.info(f"added tokens file saved in {added_tokens_file}") + + vocab_files = self.save_vocabulary(save_directory, filename_prefix=filename_prefix) + + return file_names + vocab_files + (added_tokens_file,) + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str, ...]: + """ + Save only the vocabulary of the tokenizer (vocabulary + added tokens). + + This method won't save the configuration and special token mappings of the tokenizer. Use + [`~PreTrainedTokenizerFast._save_pretrained`] to save the whole state of the tokenizer. + + Args: + save_directory (`str`): + The directory in which to save the vocabulary. + filename_prefix (`str`, *optional*): + An optional prefix to add to the named of the saved files. + + Returns: + `tuple(str)`: Paths to the files saved. + """ + raise NotImplementedError + + def tokenize(self, text: str, pair: Optional[str] = None, add_special_tokens: bool = False, **kwargs) -> list[str]: + """ + Converts a string into a sequence of tokens, replacing unknown tokens with the `unk_token`. + + Args: + text (`str`): + The sequence to be encoded. + pair (`str`, *optional*): + A second sequence to be encoded with the first. + add_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to add the special tokens associated with the corresponding model. + kwargs (additional keyword arguments, *optional*): + Will be passed to the underlying model specific encode method. See details in + [`~PreTrainedTokenizerBase.__call__`] + + Returns: + `list[str]`: The list of tokens. + """ + raise NotImplementedError + + @add_end_docstrings( + ENCODE_KWARGS_DOCSTRING, + """ + **kwargs: Passed along to the `.tokenize()` method. + """, + """ + Returns: + `list[int]`, `torch.Tensor`, `tf.Tensor` or `np.ndarray`: The tokenized ids of the text. + """, + ) + def encode( + self, + text: Union[TextInput, PreTokenizedInput, EncodedInput], + text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy, None] = None, + max_length: Optional[int] = None, + stride: int = 0, + padding_side: Optional[str] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + **kwargs, + ) -> list[int]: + """ + Converts a string to a sequence of ids (integer), using the tokenizer and vocabulary. + + Same as doing `self.convert_tokens_to_ids(self.tokenize(text))`. + + Args: + text (`str`, `list[str]` or `list[int]`): + The first sequence to be encoded. This can be a string, a list of strings (tokenized string using the + `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids` + method). + text_pair (`str`, `list[str]` or `list[int]`, *optional*): + Optional second sequence to be encoded. This can be a string, a list of strings (tokenized string using + the `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids` + method). + """ + encoded_inputs = self.encode_plus( + text, + text_pair=text_pair, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + stride=stride, + padding_side=padding_side, + return_tensors=return_tensors, + **kwargs, + ) + + return encoded_inputs["input_ids"] + + def num_special_tokens_to_add(self, pair: bool = False) -> int: + raise NotImplementedError + + def _get_padding_truncation_strategies( + self, padding=False, truncation=None, max_length=None, pad_to_multiple_of=None, verbose=True, **kwargs + ): + """ + Find the correct padding/truncation strategy + """ + + # Backward compatibility for previous behavior, maybe we should deprecate it: + # If you only set max_length, it activates truncation for max_length + if max_length is not None and padding is False and truncation is None: + if verbose: + if not self.deprecation_warnings.get("Truncation-not-explicitly-activated", False): + logger.warning( + "Truncation was not explicitly activated but `max_length` is provided a specific value, please" + " use `truncation=True` to explicitly truncate examples to max length. Defaulting to" + " 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the" + " tokenizer you can select this strategy more precisely by providing a specific strategy to" + " `truncation`." + ) + self.deprecation_warnings["Truncation-not-explicitly-activated"] = True + truncation = "longest_first" + + # Get padding strategy + if padding is not False: + if padding is True: + if verbose: + if max_length is not None and ( + truncation is None or truncation is False or truncation == "do_not_truncate" + ): + warnings.warn( + "`max_length` is ignored when `padding`=`True` and there is no truncation strategy. " + "To pad to max length, use `padding='max_length'`." + ) + padding_strategy = PaddingStrategy.LONGEST # Default to pad to the longest sequence in the batch + elif not isinstance(padding, PaddingStrategy): + padding_strategy = PaddingStrategy(padding) + elif isinstance(padding, PaddingStrategy): + padding_strategy = padding + else: + padding_strategy = PaddingStrategy.DO_NOT_PAD + + # Get truncation strategy + if truncation is not False and truncation is not None: + if truncation is True: + truncation_strategy = ( + TruncationStrategy.LONGEST_FIRST + ) # Default to truncate the longest sequences in pairs of inputs + elif not isinstance(truncation, TruncationStrategy): + truncation_strategy = TruncationStrategy(truncation) + elif isinstance(truncation, TruncationStrategy): + truncation_strategy = truncation + else: + truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE + + # Set max length if needed + if max_length is None: + if padding_strategy == PaddingStrategy.MAX_LENGTH: + if self.model_max_length > LARGE_INTEGER: + if verbose: + if not self.deprecation_warnings.get("Asking-to-pad-to-max_length", False): + logger.warning( + "Asking to pad to max_length but no maximum length is provided and the model has no" + " predefined maximum length. Default to no padding." + ) + self.deprecation_warnings["Asking-to-pad-to-max_length"] = True + padding_strategy = PaddingStrategy.DO_NOT_PAD + else: + max_length = self.model_max_length + + if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE: + if self.model_max_length > LARGE_INTEGER: + if verbose: + if not self.deprecation_warnings.get("Asking-to-truncate-to-max_length", False): + logger.warning( + "Asking to truncate to max_length but no maximum length is provided and the model has" + " no predefined maximum length. Default to no truncation." + ) + self.deprecation_warnings["Asking-to-truncate-to-max_length"] = True + truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE + else: + max_length = self.model_max_length + + # Test if we have a padding token + if padding_strategy != PaddingStrategy.DO_NOT_PAD and (self.pad_token is None or self.pad_token_id < 0): + raise ValueError( + "Asking to pad but the tokenizer does not have a padding token. " + "Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` " + "or add a new pad token via `tokenizer.add_special_tokens({'pad_token': '[PAD]'})`." + ) + + # Check that we will truncate to a multiple of pad_to_multiple_of if both are provided + if ( + truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE + and padding_strategy != PaddingStrategy.DO_NOT_PAD + and pad_to_multiple_of is not None + and max_length is not None + and (max_length % pad_to_multiple_of != 0) + ): + raise ValueError( + "Truncation and padding are both activated but " + f"truncation length ({max_length}) is not a multiple of pad_to_multiple_of ({pad_to_multiple_of})." + ) + + return padding_strategy, truncation_strategy, max_length, kwargs + + @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + def __call__( + self, + text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput], None] = None, + text_pair: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]] = None, + text_target: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput], None] = None, + text_pair_target: Optional[ + Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] + ] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy, None] = None, + max_length: Optional[int] = None, + stride: int = 0, + is_split_into_words: bool = False, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + """ + Main method to tokenize and prepare for the model one or several sequence(s) or one or several pair(s) of + sequences. + + Args: + text (`str`, `list[str]`, `list[list[str]]`, *optional*): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + text_pair (`str`, `list[str]`, `list[list[str]]`, *optional*): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + text_target (`str`, `list[str]`, `list[list[str]]`, *optional*): + The sequence or batch of sequences to be encoded as target texts. Each sequence can be a string or a + list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized), + you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + text_pair_target (`str`, `list[str]`, `list[list[str]]`, *optional*): + The sequence or batch of sequences to be encoded as target texts. Each sequence can be a string or a + list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized), + you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + """ + # To avoid duplicating + all_kwargs = { + "add_special_tokens": add_special_tokens, + "padding": padding, + "truncation": truncation, + "max_length": max_length, + "stride": stride, + "is_split_into_words": is_split_into_words, + "pad_to_multiple_of": pad_to_multiple_of, + "padding_side": padding_side, + "return_tensors": return_tensors, + "return_token_type_ids": return_token_type_ids, + "return_attention_mask": return_attention_mask, + "return_overflowing_tokens": return_overflowing_tokens, + "return_special_tokens_mask": return_special_tokens_mask, + "return_offsets_mapping": return_offsets_mapping, + "return_length": return_length, + "split_special_tokens": kwargs.pop("split_special_tokens", self.split_special_tokens), + "verbose": verbose, + } + + if return_tensors in ("tf", "jax"): + logger.warning_once( + "TensorFlow and JAX classes are deprecated and will be removed in Transformers v5. We " + "recommend migrating to PyTorch classes or pinning your version of Transformers." + ) + all_kwargs.update(kwargs) + if text is None and text_target is None: + raise ValueError("You need to specify either `text` or `text_target`.") + if text is not None: + # The context manager will send the inputs as normal texts and not text_target, but we shouldn't change the + # input mode in this case. + if not self._in_target_context_manager: + self._switch_to_input_mode() + encodings = self._call_one(text=text, text_pair=text_pair, **all_kwargs) + if text_target is not None: + self._switch_to_target_mode() + target_encodings = self._call_one(text=text_target, text_pair=text_pair_target, **all_kwargs) + # Leave back tokenizer in input mode + self._switch_to_input_mode() + + if text_target is None: + return encodings + elif text is None: + return target_encodings + else: + encodings["labels"] = target_encodings["input_ids"] + return encodings + + def _call_one( + self, + text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]], + text_pair: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy, None] = None, + max_length: Optional[int] = None, + stride: int = 0, + is_split_into_words: bool = False, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + split_special_tokens: bool = False, + **kwargs, + ) -> BatchEncoding: + # Input type checking for clearer error + def _is_valid_text_input(t): + if isinstance(t, str): + # Strings are fine + return True + elif isinstance(t, (list, tuple)): + # List are fine as long as they are... + if len(t) == 0: + # ... empty + return True + elif isinstance(t[0], str): + # ... list of strings + return True + elif isinstance(t[0], (list, tuple)): + # ... list with an empty list or with a list of strings + return len(t[0]) == 0 or isinstance(t[0][0], str) + else: + return False + else: + return False + + if not _is_valid_text_input(text): + raise ValueError( + "text input must be of type `str` (single example), `list[str]` (batch or single pretokenized example) " + "or `list[list[str]]` (batch of pretokenized examples)." + ) + + if text_pair is not None and not _is_valid_text_input(text_pair): + raise ValueError( + "text input must be of type `str` (single example), `list[str]` (batch or single pretokenized example) " + "or `list[list[str]]` (batch of pretokenized examples)." + ) + + if is_split_into_words: + is_batched = isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple)) + else: + is_batched = isinstance(text, (list, tuple)) + + if is_batched: + if isinstance(text_pair, str): + raise TypeError( + "when tokenizing batches of text, `text_pair` must be a list or tuple with the same length as" + " `text`." + ) + if text_pair is not None and len(text) != len(text_pair): + raise ValueError( + f"batch length of `text`: {len(text)} does not match batch length of `text_pair`:" + f" {len(text_pair)}." + ) + batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text + return self.batch_encode_plus( + batch_text_or_text_pairs=batch_text_or_text_pairs, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + stride=stride, + is_split_into_words=is_split_into_words, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + split_special_tokens=split_special_tokens, + **kwargs, + ) + else: + return self.encode_plus( + text=text, + text_pair=text_pair, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + stride=stride, + is_split_into_words=is_split_into_words, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + split_special_tokens=split_special_tokens, + **kwargs, + ) + + @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + def encode_plus( + self, + text: Union[TextInput, PreTokenizedInput, EncodedInput], + text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy, None] = None, + max_length: Optional[int] = None, + stride: int = 0, + is_split_into_words: bool = False, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + """ + Tokenize and prepare for the model a sequence or a pair of sequences. + + + + This method is deprecated, `__call__` should be used instead. + + + + Args: + text (`str`, `list[str]` or (for non-fast tokenizers) `list[int]`): + The first sequence to be encoded. This can be a string, a list of strings (tokenized string using the + `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids` + method). + text_pair (`str`, `list[str]` or `list[int]`, *optional*): + Optional second sequence to be encoded. This can be a string, a list of strings (tokenized string using + the `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids` + method). + """ + + padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + verbose=verbose, + **kwargs, + ) + + return self._encode_plus( + text=text, + text_pair=text_pair, + add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + stride=stride, + is_split_into_words=is_split_into_words, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + split_special_tokens=kwargs.pop("split_special_tokens", self.split_special_tokens), + **kwargs, + ) + + def _encode_plus( + self, + text: Union[TextInput, PreTokenizedInput, EncodedInput], + text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None, + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + is_split_into_words: bool = False, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + split_special_tokens: bool = False, + **kwargs, + ) -> BatchEncoding: + raise NotImplementedError + + @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + def batch_encode_plus( + self, + batch_text_or_text_pairs: Union[ + list[TextInput], + list[TextInputPair], + list[PreTokenizedInput], + list[PreTokenizedInputPair], + list[EncodedInput], + list[EncodedInputPair], + ], + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy, None] = None, + max_length: Optional[int] = None, + stride: int = 0, + is_split_into_words: bool = False, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + split_special_tokens: bool = False, + **kwargs, + ) -> BatchEncoding: + """ + Tokenize and prepare for the model a list of sequences or a list of pairs of sequences. + + + + This method is deprecated, `__call__` should be used instead. + + + + Args: + batch_text_or_text_pairs (`list[str]`, `list[tuple[str, str]]`, `list[list[str]]`, `list[tuple[list[str], list[str]]]`, and for not-fast tokenizers, also `list[list[int]]`, `list[tuple[list[int], list[int]]]`): + Batch of sequences or pair of sequences to be encoded. This can be a list of + string/string-sequences/int-sequences or a list of pair of string/string-sequences/int-sequence (see + details in `encode_plus`). + """ + + # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' + padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + verbose=verbose, + **kwargs, + ) + + return self._batch_encode_plus( + batch_text_or_text_pairs=batch_text_or_text_pairs, + add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + stride=stride, + is_split_into_words=is_split_into_words, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + split_special_tokens=split_special_tokens, + **kwargs, + ) + + def _batch_encode_plus( + self, + batch_text_or_text_pairs: Union[ + list[TextInput], + list[TextInputPair], + list[PreTokenizedInput], + list[PreTokenizedInputPair], + list[EncodedInput], + list[EncodedInputPair], + ], + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + is_split_into_words: bool = False, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + split_special_tokens: bool = False, + **kwargs, + ) -> BatchEncoding: + raise NotImplementedError + + def pad( + self, + encoded_inputs: Union[ + BatchEncoding, + list[BatchEncoding], + dict[str, EncodedInput], + dict[str, list[EncodedInput]], + list[dict[str, EncodedInput]], + ], + padding: Union[bool, str, PaddingStrategy] = True, + max_length: Optional[int] = None, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_attention_mask: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + verbose: bool = True, + ) -> BatchEncoding: + """ + Pad a single encoded input or a batch of encoded inputs up to predefined length or to the max sequence length + in the batch. + + Padding side (left/right) padding token ids are defined at the tokenizer level (with `self.padding_side`, + `self.pad_token_id` and `self.pad_token_type_id`). + + Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the + text followed by a call to the `pad` method to get a padded encoding. + + + + If the `encoded_inputs` passed are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the + result will use the same type unless you provide a different tensor type with `return_tensors`. In the case of + PyTorch tensors, you will lose the specific device of your tensors however. + + + + Args: + encoded_inputs ([`BatchEncoding`], list of [`BatchEncoding`], `dict[str, list[int]]`, `dict[str, list[list[int]]` or `list[dict[str, list[int]]]`): + Tokenized inputs. Can represent one input ([`BatchEncoding`] or `dict[str, list[int]]`) or a batch of + tokenized inputs (list of [`BatchEncoding`], *dict[str, list[list[int]]]* or *list[dict[str, + list[int]]]*) so you can use this method during preprocessing as well as in a PyTorch Dataloader + collate function. + + Instead of `list[int]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors), see + the note above for the return type. + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding + index) among: + + - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different + lengths). + max_length (`int`, *optional*): + Maximum length of the returned list and optionally padding length (see above). + pad_to_multiple_of (`int`, *optional*): + If set will pad the sequence to a multiple of the provided value. + + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability + `>= 7.5` (Volta). + padding_side (`str`, *optional*): + The side on which the model should have padding applied. Should be selected between ['right', 'left']. + Default value is picked from the class attribute of the same name. + return_attention_mask (`bool`, *optional*): + Whether to return the attention mask. If left to the default, will return the attention mask according + to the specific tokenizer's default, defined by the `return_outputs` attribute. + + [What are attention masks?](../glossary#attention-mask) + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors instead of list of python integers. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return Numpy `np.ndarray` objects. + verbose (`bool`, *optional*, defaults to `True`): + Whether or not to print more information and warnings. + """ + if self.__class__.__name__.endswith("Fast"): + if not self.deprecation_warnings.get("Asking-to-pad-a-fast-tokenizer", False): + logger.warning_advice( + f"You're using a {self.__class__.__name__} tokenizer. Please note that with a fast tokenizer," + " using the `__call__` method is faster than using a method to encode the text followed by a call" + " to the `pad` method to get a padded encoding." + ) + self.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True + + # If we have a list of dicts, let's convert it in a dict of lists + # We do this to allow using this method as a collate_fn function in PyTorch Dataloader + if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], Mapping): + encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0]} + + # The model's main input name, usually `input_ids`, has been passed for padding + if self.model_input_names[0] not in encoded_inputs: + raise ValueError( + "You should supply an encoding or a list of encodings to this method " + f"that includes {self.model_input_names[0]}, but you provided {list(encoded_inputs.keys())}" + ) + + required_input = encoded_inputs[self.model_input_names[0]] + + if required_input is None or (isinstance(required_input, Sized) and len(required_input) == 0): + if return_attention_mask: + encoded_inputs["attention_mask"] = [] + return encoded_inputs + + # If we have PyTorch/TF/NumPy tensors/arrays as inputs, we cast them as python objects + # and rebuild them afterwards if no return_tensors is specified + # Note that we lose the specific device the tensor may be on for PyTorch + + first_element = required_input[0] + if isinstance(first_element, (list, tuple)): + # first_element might be an empty list/tuple in some edge cases so we grab the first non empty element. + for item in required_input: + if len(item) != 0: + first_element = item[0] + break + # At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do. + if not isinstance(first_element, (int, list, tuple)): + if is_tf_tensor(first_element): + return_tensors = "tf" if return_tensors is None else return_tensors + elif is_torch_tensor(first_element): + return_tensors = "pt" if return_tensors is None else return_tensors + elif isinstance(first_element, np.ndarray): + return_tensors = "np" if return_tensors is None else return_tensors + else: + raise ValueError( + f"type of {first_element} unknown: {type(first_element)}. " + "Should be one of a python, numpy, pytorch or tensorflow object." + ) + + for key, value in encoded_inputs.items(): + encoded_inputs[key] = to_py_obj(value) + + # Convert padding_strategy in PaddingStrategy + padding_strategy, _, max_length, _ = self._get_padding_truncation_strategies( + padding=padding, max_length=max_length, verbose=verbose + ) + + required_input = encoded_inputs[self.model_input_names[0]] + if required_input and not isinstance(required_input[0], (list, tuple)): + encoded_inputs = self._pad( + encoded_inputs, + max_length=max_length, + padding_strategy=padding_strategy, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_attention_mask=return_attention_mask, + ) + return BatchEncoding(encoded_inputs, tensor_type=return_tensors) + + batch_size = len(required_input) + assert all(len(v) == batch_size for v in encoded_inputs.values()), ( + "Some items in the output dictionary have a different batch size than others." + ) + + if padding_strategy == PaddingStrategy.LONGEST: + max_length = max(len(inputs) for inputs in required_input) + padding_strategy = PaddingStrategy.MAX_LENGTH + + batch_outputs = {} + for i in range(batch_size): + inputs = {k: v[i] for k, v in encoded_inputs.items()} + outputs = self._pad( + inputs, + max_length=max_length, + padding_strategy=padding_strategy, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_attention_mask=return_attention_mask, + ) + + for key, value in outputs.items(): + if key not in batch_outputs: + batch_outputs[key] = [] + batch_outputs[key].append(value) + + return BatchEncoding(batch_outputs, tensor_type=return_tensors) + + def create_token_type_ids_from_sequences( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: + """ + Create the token type IDs corresponding to the sequences passed. [What are token type + IDs?](../glossary#token-type-ids) + + Should be overridden in a subclass if the model has a special way of building those. + + Args: + token_ids_0 (`list[int]`): The first tokenized sequence. + token_ids_1 (`list[int]`, *optional*): The second tokenized sequence. + + Returns: + `list[int]`: The token type ids. + """ + cls_len = int(getattr(self, "cls_token_id", None) is not None) + sep_len = int(getattr(self, "sep_token_id", None) is not None) + + if token_ids_1 is None: + return [0] * (cls_len + len(token_ids_0) + sep_len) + + return [0] * (cls_len + len(token_ids_0) + sep_len) + [1] * (len(token_ids_1) + sep_len) + + def build_inputs_with_special_tokens( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. + + This implementation does not add special tokens and this method should be overridden in a subclass. + + Args: + token_ids_0 (`list[int]`): The first tokenized sequence. + token_ids_1 (`list[int]`, *optional*): The second tokenized sequence. + + Returns: + `list[int]`: The model input with special tokens. + """ + if token_ids_1 is None: + return token_ids_0 + return token_ids_0 + token_ids_1 + + @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + def prepare_for_model( + self, + ids: list[int], + pair_ids: Optional[list[int]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy, None] = None, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + prepend_batch_axis: bool = False, + **kwargs, + ) -> BatchEncoding: + """ + Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. It + adds special tokens, truncates sequences if overflowing while taking into account the special tokens and + manages a moving window (with user defined stride) for overflowing tokens. Please Note, for *pair_ids* + different than `None` and *truncation_strategy = longest_first* or `True`, it is not possible to return + overflowing tokens. Such a combination of arguments will raise an error. + + Args: + ids (`list[int]`): + Tokenized input ids of the first sequence. Can be obtained from a string by chaining the `tokenize` and + `convert_tokens_to_ids` methods. + pair_ids (`list[int]`, *optional*): + Tokenized input ids of the second sequence. Can be obtained from a string by chaining the `tokenize` + and `convert_tokens_to_ids` methods. + """ + + # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' + padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + verbose=verbose, + **kwargs, + ) + + pair = pair_ids is not None + len_ids = len(ids) + len_pair_ids = len(pair_ids) if pair else 0 + + if return_token_type_ids and not add_special_tokens: + raise ValueError( + "Asking to return token_type_ids while setting add_special_tokens to False " + "results in an undefined behavior. Please set add_special_tokens to True or " + "set return_token_type_ids to None." + ) + + if ( + return_overflowing_tokens + and truncation_strategy == TruncationStrategy.LONGEST_FIRST + and pair_ids is not None + ): + raise ValueError( + "Not possible to return overflowing tokens for pair of sequences with the " + "`longest_first`. Please select another truncation strategy than `longest_first`, " + "for instance `only_second` or `only_first`." + ) + + # Load from model defaults + if return_token_type_ids is None: + return_token_type_ids = "token_type_ids" in self.model_input_names + if return_attention_mask is None: + return_attention_mask = "attention_mask" in self.model_input_names + + encoded_inputs = {} + + # Compute the total size of the returned encodings + total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0) + + # Truncation: Handle max sequence length + overflowing_tokens = [] + if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length: + ids, pair_ids, overflowing_tokens = self.truncate_sequences( + ids, + pair_ids=pair_ids, + num_tokens_to_remove=total_len - max_length, + truncation_strategy=truncation_strategy, + stride=stride, + ) + + if return_overflowing_tokens: + encoded_inputs["overflowing_tokens"] = overflowing_tokens + encoded_inputs["num_truncated_tokens"] = total_len - max_length + + # Add special tokens + if add_special_tokens: + sequence = self.build_inputs_with_special_tokens(ids, pair_ids) + token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids) + else: + sequence = ids + pair_ids if pair else ids + token_type_ids = [0] * len(ids) + ([0] * len(pair_ids) if pair else []) + + # Build output dictionary + encoded_inputs["input_ids"] = sequence + if return_token_type_ids: + encoded_inputs["token_type_ids"] = token_type_ids + if return_special_tokens_mask: + if add_special_tokens: + encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids) + else: + encoded_inputs["special_tokens_mask"] = [0] * len(sequence) + + # Check lengths + self._eventual_warn_about_too_long_sequence(encoded_inputs["input_ids"], max_length, verbose) + + # Padding + if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask: + encoded_inputs = self.pad( + encoded_inputs, + max_length=max_length, + padding=padding_strategy.value, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_attention_mask=return_attention_mask, + ) + + if return_length: + encoded_inputs["length"] = len(encoded_inputs["input_ids"]) + + batch_outputs = BatchEncoding( + encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis + ) + + return batch_outputs + + def truncate_sequences( + self, + ids: list[int], + pair_ids: Optional[list[int]] = None, + num_tokens_to_remove: int = 0, + truncation_strategy: Union[str, TruncationStrategy] = "longest_first", + stride: int = 0, + ) -> tuple[list[int], list[int], list[int]]: + """ + Truncates a sequence pair in-place following the strategy. + + Args: + ids (`list[int]`): + Tokenized input ids of the first sequence. Can be obtained from a string by chaining the `tokenize` and + `convert_tokens_to_ids` methods. + pair_ids (`list[int]`, *optional*): + Tokenized input ids of the second sequence. Can be obtained from a string by chaining the `tokenize` + and `convert_tokens_to_ids` methods. + num_tokens_to_remove (`int`, *optional*, defaults to 0): + Number of tokens to remove using the truncation strategy. + truncation_strategy (`str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `'longest_first'`): + The strategy to follow for truncation. Can be: + + - `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or to the + maximum acceptable input length for the model if that argument is not provided. This will truncate + token by token, removing a token from the longest sequence in the pair if a pair of sequences (or a + batch of pairs) is provided. + - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the + maximum acceptable input length for the model if that argument is not provided. This will only + truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the + maximum acceptable input length for the model if that argument is not provided. This will only + truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + - `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths greater + than the model maximum admissible input size). + stride (`int`, *optional*, defaults to 0): + If set to a positive number, the overflowing tokens returned will contain some tokens from the main + sequence returned. The value of this argument defines the number of additional tokens. + + Returns: + `tuple[list[int], list[int], list[int]]`: The truncated `ids`, the truncated `pair_ids` and the list of + overflowing tokens. Note: The *longest_first* strategy returns empty list of overflowing tokens if a pair + of sequences (or a batch of pairs) is provided. + """ + if num_tokens_to_remove <= 0: + return ids, pair_ids, [] + + if not isinstance(truncation_strategy, TruncationStrategy): + truncation_strategy = TruncationStrategy(truncation_strategy) + + overflowing_tokens = [] + if truncation_strategy == TruncationStrategy.ONLY_FIRST or ( + truncation_strategy == TruncationStrategy.LONGEST_FIRST and pair_ids is None + ): + if len(ids) > num_tokens_to_remove: + window_len = min(len(ids), stride + num_tokens_to_remove) + if self.truncation_side == "left": + overflowing_tokens = ids[:window_len] + ids = ids[num_tokens_to_remove:] + elif self.truncation_side == "right": + overflowing_tokens = ids[-window_len:] + ids = ids[:-num_tokens_to_remove] + else: + raise ValueError(f"invalid truncation strategy: {self.truncation_side}, use 'left' or 'right'.") + + else: + error_msg = ( + f"We need to remove {num_tokens_to_remove} to truncate the input " + f"but the first sequence has a length {len(ids)}. " + ) + if truncation_strategy == TruncationStrategy.ONLY_FIRST: + error_msg = ( + error_msg + "Please select another truncation strategy than " + f"{truncation_strategy}, for instance 'longest_first' or 'only_second'." + ) + logger.error(error_msg) + elif truncation_strategy == TruncationStrategy.LONGEST_FIRST: + logger.warning( + "Be aware, overflowing tokens are not returned for the setting you have chosen," + f" i.e. sequence pairs with the '{TruncationStrategy.LONGEST_FIRST.value}' " + "truncation strategy. So the returned list will always be empty even if some " + "tokens have been removed." + ) + len_pair_ids = len(pair_ids) if pair_ids is not None else 0 + len_ids = len(ids) + first_remove = min(abs(len_pair_ids - len_ids), num_tokens_to_remove) + second_remove = num_tokens_to_remove - first_remove + if len_ids > len_pair_ids: + ids_to_move = first_remove + second_remove // 2 + pair_ids_to_move = second_remove - second_remove // 2 + else: + ids_to_move = second_remove // 2 + pair_ids_to_move = first_remove + second_remove - (second_remove // 2) + + if self.truncation_side == "right": + ids = ids[:-ids_to_move] if ids_to_move > 0 else ids + pair_ids = pair_ids[:-pair_ids_to_move] if pair_ids is not None and pair_ids_to_move > 0 else pair_ids + elif self.truncation_side == "left": + ids = ids[ids_to_move:] + pair_ids = pair_ids[pair_ids_to_move:] if pair_ids is not None else None + else: + raise ValueError(f"invalid truncation strategy:{self.truncation_side}") + + elif truncation_strategy == TruncationStrategy.ONLY_SECOND and pair_ids is not None: + if len(pair_ids) > num_tokens_to_remove: + window_len = min(len(pair_ids), stride + num_tokens_to_remove) + if self.truncation_side == "right": + overflowing_tokens = pair_ids[-window_len:] + pair_ids = pair_ids[:-num_tokens_to_remove] + elif self.truncation_side == "left": + overflowing_tokens = pair_ids[:window_len] + pair_ids = pair_ids[num_tokens_to_remove:] + else: + raise ValueError(f"invalid truncation strategy:{self.truncation_side}") + else: + logger.error( + f"We need to remove {num_tokens_to_remove} to truncate the input " + f"but the second sequence has a length {len(pair_ids)}. " + f"Please select another truncation strategy than {truncation_strategy}, " + "for instance 'longest_first' or 'only_first'." + ) + + return (ids, pair_ids, overflowing_tokens) + + def _pad( + self, + encoded_inputs: Union[dict[str, EncodedInput], BatchEncoding], + max_length: Optional[int] = None, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_attention_mask: Optional[bool] = None, + ) -> dict: + """ + Pad encoded inputs (on left/right and up to predefined length or max length in the batch) + + Args: + encoded_inputs: + Dictionary of tokenized inputs (`list[int]`) or batch of tokenized inputs (`list[list[int]]`). + max_length: maximum length of the returned list and optionally padding length (see below). + Will truncate by taking into account the special tokens. + padding_strategy: PaddingStrategy to use for padding. + + - PaddingStrategy.LONGEST Pad to the longest sequence in the batch + - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) + - PaddingStrategy.DO_NOT_PAD: Do not pad + The tokenizer padding sides are defined in `padding_side` argument: + + - 'left': pads on the left of the sequences + - 'right': pads on the right of the sequences + pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability + `>= 7.5` (Volta). + padding_side: + The side on which the model should have padding applied. Should be selected between ['right', 'left']. + Default value is picked from the class attribute of the same name. + return_attention_mask: + (optional) Set to False to avoid returning attention mask (default: set to model specifics) + """ + # Load from model defaults + if return_attention_mask is None: + return_attention_mask = "attention_mask" in self.model_input_names + + required_input = encoded_inputs[self.model_input_names[0]] + + if padding_strategy == PaddingStrategy.LONGEST: + max_length = len(required_input) + + if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): + max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + + needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length + + # Initialize attention mask if not present. + if return_attention_mask and "attention_mask" not in encoded_inputs: + encoded_inputs["attention_mask"] = [1] * len(required_input) + + if needs_to_be_padded: + difference = max_length - len(required_input) + padding_side = padding_side if padding_side is not None else self.padding_side + + if padding_side == "right": + if return_attention_mask: + encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference + if "token_type_ids" in encoded_inputs: + encoded_inputs["token_type_ids"] = ( + encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference + ) + if "special_tokens_mask" in encoded_inputs: + encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference + encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference + elif padding_side == "left": + if return_attention_mask: + encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"] + if "token_type_ids" in encoded_inputs: + encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[ + "token_type_ids" + ] + if "special_tokens_mask" in encoded_inputs: + encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"] + encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input + else: + raise ValueError(f"Invalid padding strategy:{padding_side}") + + return encoded_inputs + + def convert_tokens_to_string(self, tokens: list[str]) -> str: + """ + Converts a sequence of tokens in a single string. The most simple way to do it is `" ".join(tokens)` but we + often want to remove sub-word tokenization artifacts at the same time. + + Args: + tokens (`list[str]`): The token to join in a string. + + Returns: + `str`: The joined tokens. + """ + raise NotImplementedError + + def batch_decode( + self, + sequences: Union[list[int], list[list[int]], "np.ndarray", "torch.Tensor", "tf.Tensor"], + skip_special_tokens: bool = False, + clean_up_tokenization_spaces: Optional[bool] = None, + **kwargs, + ) -> list[str]: + """ + Convert a list of lists of token ids into a list of strings by calling decode. + + Args: + sequences (`Union[list[int], list[list[int]], np.ndarray, torch.Tensor, tf.Tensor]`): + List of tokenized input ids. Can be obtained using the `__call__` method. + skip_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to remove special tokens in the decoding. + clean_up_tokenization_spaces (`bool`, *optional*): + Whether or not to clean up the tokenization spaces. If `None`, will default to + `self.clean_up_tokenization_spaces`. + kwargs (additional keyword arguments, *optional*): + Will be passed to the underlying model specific decode method. + + Returns: + `list[str]`: The list of decoded sentences. + """ + return [ + self.decode( + seq, + skip_special_tokens=skip_special_tokens, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + **kwargs, + ) + for seq in sequences + ] + + def decode( + self, + token_ids: Union[int, list[int], np.ndarray, "torch.Tensor"], + skip_special_tokens: bool = False, + clean_up_tokenization_spaces: Optional[bool] = None, + **kwargs, + ) -> str: + """ + Converts a sequence of ids in a string, using the tokenizer and vocabulary with options to remove special + tokens and clean up tokenization spaces. + + Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`. + + Args: + token_ids (`Union[int, list[int], np.ndarray, torch.Tensor, tf.Tensor]`): + List of tokenized input ids. Can be obtained using the `__call__` method. + skip_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to remove special tokens in the decoding. + clean_up_tokenization_spaces (`bool`, *optional*): + Whether or not to clean up the tokenization spaces. If `None`, will default to + `self.clean_up_tokenization_spaces`. + kwargs (additional keyword arguments, *optional*): + Will be passed to the underlying model specific decode method. + + Returns: + `str`: The decoded sentence. + """ + # Convert inputs to python lists + token_ids = to_py_obj(token_ids) + + return self._decode( + token_ids=token_ids, + skip_special_tokens=skip_special_tokens, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + **kwargs, + ) + + def _decode( + self, + token_ids: Union[int, list[int]], + skip_special_tokens: bool = False, + clean_up_tokenization_spaces: Optional[bool] = None, + **kwargs, + ) -> str: + raise NotImplementedError + + def get_special_tokens_mask( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False + ) -> list[int]: + """ + Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods. + + Args: + token_ids_0 (`list[int]`): + List of ids of the first sequence. + token_ids_1 (`list[int]`, *optional*): + List of ids of the second sequence. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + assert already_has_special_tokens and token_ids_1 is None, ( + "You cannot use ``already_has_special_tokens=False`` with this tokenizer. " + "Please use a slow (full python) tokenizer to activate this argument. " + "Or set `return_special_tokens_mask=True` when calling the encoding method " + "to get the special tokens mask in any tokenizer. " + ) + + all_special_ids = self.all_special_ids # cache the property + + special_tokens_mask = [1 if token in all_special_ids else 0 for token in token_ids_0] + + return special_tokens_mask + + @staticmethod + def clean_up_tokenization(out_string: str) -> str: + """ + Clean up a list of simple English tokenization artifacts like spaces before punctuations and abbreviated forms. + + Args: + out_string (`str`): The text to clean up. + + Returns: + `str`: The cleaned-up string. + """ + out_string = ( + out_string.replace(" .", ".") + .replace(" ?", "?") + .replace(" !", "!") + .replace(" ,", ",") + .replace(" ' ", "'") + .replace(" n't", "n't") + .replace(" 'm", "'m") + .replace(" 's", "'s") + .replace(" 've", "'ve") + .replace(" 're", "'re") + ) + return out_string + + def _eventual_warn_about_too_long_sequence(self, ids: list[int], max_length: Optional[int], verbose: bool): + """ + Depending on the input and internal state we might trigger a warning about a sequence that is too long for its + corresponding model + + Args: + ids (`list[str]`): The ids produced by the tokenization + max_length (`int`, *optional*): The max_length desired (does not trigger a warning if it is set) + verbose (`bool`): Whether or not to print more information and warnings. + + """ + if max_length is None and len(ids) > self.model_max_length and verbose and self.model_max_length != 0: + if not self.deprecation_warnings.get("sequence-length-is-longer-than-the-specified-maximum", False): + logger.warning( + "Token indices sequence length is longer than the specified maximum sequence length " + f"for this model ({len(ids)} > {self.model_max_length}). Running this sequence through the model " + "will result in indexing errors" + ) + self.deprecation_warnings["sequence-length-is-longer-than-the-specified-maximum"] = True + + def _switch_to_input_mode(self): + """ + Private method to put the tokenizer in input mode (when it has different modes for input/outputs) + """ + pass + + def _switch_to_target_mode(self): + """ + Private method to put the tokenizer in target mode (when it has different modes for input/outputs) + """ + pass + + @contextmanager + def as_target_tokenizer(self): + """ + Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to + sequence-to-sequence models that need a slightly different processing for the labels. + """ + warnings.warn( + "`as_target_tokenizer` is deprecated and will be removed in v5 of Transformers. You can tokenize your " + "labels by using the argument `text_target` of the regular `__call__` method (either in the same call as " + "your input texts if you use the same keyword arguments, or in a separate call." + ) + self._switch_to_target_mode() + self._in_target_context_manager = True + yield + self._in_target_context_manager = False + self._switch_to_input_mode() + + @classmethod + def register_for_auto_class(cls, auto_class="AutoTokenizer"): + """ + Register this class with a given auto class. This should only be used for custom tokenizers as the ones in the + library are already mapped with `AutoTokenizer`. + + + + Args: + auto_class (`str` or `type`, *optional*, defaults to `"AutoTokenizer"`): + The auto class to register this new tokenizer with. + """ + if not isinstance(auto_class, str): + auto_class = auto_class.__name__ + + import transformers.models.auto as auto_module + + if not hasattr(auto_module, auto_class): + raise ValueError(f"{auto_class} is not a valid auto class.") + + cls._auto_class = auto_class + + def prepare_seq2seq_batch( + self, + src_texts: list[str], + tgt_texts: Optional[list[str]] = None, + max_length: Optional[int] = None, + max_target_length: Optional[int] = None, + padding: str = "longest", + return_tensors: Optional[str] = None, + truncation: bool = True, + **kwargs, + ) -> BatchEncoding: + """ + Prepare model inputs for translation. For best performance, translate one sentence at a time. + + Arguments: + src_texts (`list[str]`): + List of documents to summarize or source language texts. + tgt_texts (`list`, *optional*): + List of summaries or target language texts. + max_length (`int`, *optional*): + Controls the maximum length for encoder inputs (documents to summarize or source language texts) If + left unset or set to `None`, this will use the predefined model maximum length if a maximum length is + required by one of the truncation/padding parameters. If the model has no specific maximum input length + (like XLNet) truncation/padding to a maximum length will be deactivated. + max_target_length (`int`, *optional*): + Controls the maximum length of decoder inputs (target language texts or summaries) If left unset or set + to `None`, this will use the max_length value. + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): + Activates and controls padding. Accepts the following values: + + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors instead of list of python integers. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return Numpy `np.ndarray` objects. + truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `True`): + Activates and controls truncation. Accepts the following values: + + - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or + to the maximum acceptable input length for the model if that argument is not provided. This will + truncate token by token, removing a token from the longest sequence in the pair if a pair of + sequences (or a batch of pairs) is provided. + - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the + maximum acceptable input length for the model if that argument is not provided. This will only + truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the + maximum acceptable input length for the model if that argument is not provided. This will only + truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths + greater than the model maximum admissible input size). + **kwargs: + Additional keyword arguments passed along to `self.__call__`. + + Return: + [`BatchEncoding`]: A [`BatchEncoding`] with the following fields: + + - **input_ids** -- List of token ids to be fed to the encoder. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model. + - **labels** -- List of token ids for tgt_texts. + + The full set of keys `[input_ids, attention_mask, labels]`, will only be returned if tgt_texts is passed. + Otherwise, input_ids, attention_mask will be the only keys. + """ + # docstyle-ignore + formatted_warning = """ +`prepare_seq2seq_batch` is deprecated and will be removed in version 5 of HuggingFace Transformers. Use the regular +`__call__` method to prepare your inputs and targets. + +Here is a short example: + +model_inputs = tokenizer(src_texts, text_target=tgt_texts, ...) + +If you either need to use different keyword arguments for the source and target texts, you should do two calls like +this: + +model_inputs = tokenizer(src_texts, ...) +labels = tokenizer(text_target=tgt_texts, ...) +model_inputs["labels"] = labels["input_ids"] + +See the documentation of your specific tokenizer for more details on the specific arguments to the tokenizer of choice. +For a more complete example, see the implementation of `prepare_seq2seq_batch`. +""" + warnings.warn(formatted_warning, FutureWarning) + # mBART-specific kwargs that should be ignored by other models. + kwargs.pop("src_lang", None) + kwargs.pop("tgt_lang", None) + if max_length is None: + max_length = self.model_max_length + model_inputs = self( + src_texts, + add_special_tokens=True, + return_tensors=return_tensors, + max_length=max_length, + padding=padding, + truncation=truncation, + **kwargs, + ) + if tgt_texts is None: + return model_inputs + # Process tgt_texts + if max_target_length is None: + max_target_length = max_length + with self.as_target_tokenizer(): + labels = self( + tgt_texts, + add_special_tokens=True, + return_tensors=return_tensors, + padding=padding, + max_length=max_target_length, + truncation=truncation, + **kwargs, + ) + model_inputs["labels"] = labels["input_ids"] + return model_inputs + + +def get_fast_tokenizer_file(tokenization_files: list[str]) -> str: + """ + Get the tokenization file to use for this version of transformers. + + Args: + tokenization_files (`list[str]`): The list of available configuration files. + + Returns: + `str`: The tokenization file to use. + """ + tokenizer_files_map = {} + for file_name in tokenization_files: + search = _re_tokenizer_file.search(file_name) + if search is not None: + v = search.groups()[0] + tokenizer_files_map[v] = file_name + available_versions = sorted(tokenizer_files_map.keys()) + + # Defaults to FULL_TOKENIZER_FILE and then try to look at some newer versions. + tokenizer_file = FULL_TOKENIZER_FILE + transformers_version = version.parse(__version__) + for v in available_versions: + if version.parse(v) <= transformers_version: + tokenizer_file = tokenizer_files_map[v] + else: + # No point going further since the versions are sorted. + break + + return tokenizer_file + + +# To update the docstring, we need to copy the method, otherwise we change the original docstring. +PreTrainedTokenizerBase.push_to_hub = copy_func(PreTrainedTokenizerBase.push_to_hub) +if PreTrainedTokenizerBase.push_to_hub.__doc__ is not None: + PreTrainedTokenizerBase.push_to_hub.__doc__ = PreTrainedTokenizerBase.push_to_hub.__doc__.format( + object="tokenizer", object_class="AutoTokenizer", object_files="tokenizer files" + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/tokenization_utils_fast.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/tokenization_utils_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..fe4873d61b378410c90db80b4b19094f3c6ab292 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/tokenization_utils_fast.py @@ -0,0 +1,922 @@ +# Copyright 2020 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tokenization classes for fast tokenizers (provided by HuggingFace's tokenizers library). For slow (python) tokenizers +see tokenization_utils.py +""" + +import copy +import json +import os +from collections import defaultdict +from collections.abc import Iterable +from typing import Any, Optional, Union + +import tokenizers.pre_tokenizers as pre_tokenizers_fast +from tokenizers import Encoding as EncodingFast +from tokenizers import Tokenizer as TokenizerFast +from tokenizers.decoders import Decoder as DecoderFast +from tokenizers.trainers import BpeTrainer, UnigramTrainer, WordLevelTrainer, WordPieceTrainer + +from .convert_slow_tokenizer import convert_slow_tokenizer +from .integrations.ggml import convert_gguf_tokenizer +from .modeling_gguf_pytorch_utils import load_gguf_checkpoint +from .tokenization_utils import PreTrainedTokenizer +from .tokenization_utils_base import ( + INIT_TOKENIZER_DOCSTRING, + AddedToken, + BatchEncoding, + PreTokenizedInput, + PreTokenizedInputPair, + PreTrainedTokenizerBase, + SpecialTokensMixin, + TextInput, + TextInputPair, + TruncationStrategy, +) +from .utils import PaddingStrategy, add_end_docstrings, logging + + +logger = logging.get_logger(__name__) + +# Fast tokenizers (provided by HuggingFace tokenizer's library) can be saved in a single file +TOKENIZER_FILE = "tokenizer.json" +SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json" +TOKENIZER_CONFIG_FILE = "tokenizer_config.json" +TIKTOKEN_VOCAB_FILE = "tokenizer.model" + +# Slow tokenizers have an additional added tokens files +ADDED_TOKENS_FILE = "added_tokens.json" + +INIT_TOKENIZER_DOCSTRING += """ + tokenizer_object ([`tokenizers.Tokenizer`]): + A [`tokenizers.Tokenizer`] object from 🤗 tokenizers to instantiate from. See [Using tokenizers from 🤗 + tokenizers](../fast_tokenizers) for more information. + tokenizer_file ([`str`]): + A path to a local JSON file representing a previously serialized [`tokenizers.Tokenizer`] object from 🤗 + tokenizers. +""" + +MODEL_TO_TRAINER_MAPPING = { + "BPE": BpeTrainer, + "Unigram": UnigramTrainer, + "WordLevel": WordLevelTrainer, + "WordPiece": WordPieceTrainer, +} + +VOCAB_FILES_NAMES = {"tokenizer_file": TOKENIZER_FILE, "vocab_file": TIKTOKEN_VOCAB_FILE} + + +@add_end_docstrings(INIT_TOKENIZER_DOCSTRING) +class PreTrainedTokenizerFast(PreTrainedTokenizerBase): + """ + Base class for all fast tokenizers (wrapping HuggingFace tokenizers library). + + Inherits from [`~tokenization_utils_base.PreTrainedTokenizerBase`]. + + Handles all the shared methods for tokenization and special tokens, as well as methods for + downloading/caching/loading pretrained tokenizers, as well as adding tokens to the vocabulary. + + This class also contains the added tokens in a unified way on top of all tokenizers so we don't have to handle the + specific vocabulary augmentation methods of the various underlying dictionary structures (BPE, sentencepiece...). + """ + + vocab_files_names = VOCAB_FILES_NAMES + slow_tokenizer_class: Optional[type[PreTrainedTokenizer]] = None + + def __init__(self, *args, **kwargs): + tokenizer_object = kwargs.pop("tokenizer_object", None) + slow_tokenizer = kwargs.pop("__slow_tokenizer", None) + gguf_file = kwargs.pop("gguf_file", None) + fast_tokenizer_file = kwargs.pop("tokenizer_file", None) + from_slow = kwargs.pop("from_slow", False) + added_tokens_decoder = kwargs.pop("added_tokens_decoder", {}) + self.add_prefix_space = kwargs.get("add_prefix_space", False) + + if from_slow and slow_tokenizer is None and self.slow_tokenizer_class is None: + raise ValueError( + "Cannot instantiate this tokenizer from a slow version. If it's based on sentencepiece, make sure you " + "have sentencepiece installed." + ) + + if tokenizer_object is not None: + fast_tokenizer = copy.deepcopy(tokenizer_object) + elif fast_tokenizer_file is not None and not from_slow: + # We have a serialization from tokenizers which let us directly build the backend + fast_tokenizer = TokenizerFast.from_file(fast_tokenizer_file) + elif slow_tokenizer: + # We need to convert a slow tokenizer to build the backend + fast_tokenizer = convert_slow_tokenizer(slow_tokenizer) + elif gguf_file is not None: + # We need to convert a slow tokenizer to build the backend + gguf_param = load_gguf_checkpoint(kwargs.get("vocab_file")) + architecture = gguf_param["config"]["model_type"] + tokenizer_dict = gguf_param["tokenizer"] + tokenizer_config = gguf_param["tokenizer_config"] + fast_tokenizer, additional_kwargs = convert_gguf_tokenizer(architecture, tokenizer_dict) + kwargs.update(tokenizer_config) + if len(additional_kwargs) > 0: + kwargs.update(additional_kwargs) + elif self.slow_tokenizer_class is not None and slow_tokenizer is not False: + # We need to create and convert a slow tokenizer to build the backend + slow_tokenizer = self.slow_tokenizer_class(*args, **kwargs) + fast_tokenizer = convert_slow_tokenizer(slow_tokenizer) + elif not slow_tokenizer: + # We tried loading a slow_tokenizer with spm and failed, try to load with tiktoken + self.vocab_file = kwargs.get("vocab_file") + self.additional_special_tokens = kwargs.get("additional_special_tokens", []) + fast_tokenizer = convert_slow_tokenizer(self, from_tiktoken=True) + slow_tokenizer = None + else: + raise ValueError( + "Couldn't instantiate the backend tokenizer from one of: \n" + "(1) a `tokenizers` library serialization file, \n" + "(2) a slow tokenizer instance to convert or \n" + "(3) an equivalent slow tokenizer class to instantiate and convert. \n" + "You need to have sentencepiece or tiktoken installed to convert a slow tokenizer to a fast one." + ) + + self._tokenizer = fast_tokenizer + + if slow_tokenizer is not None: + kwargs.update(slow_tokenizer.init_kwargs) + + self._decode_use_source_tokenizer = False + + _truncation = self._tokenizer.truncation + + if _truncation is not None: + self._tokenizer.enable_truncation(**_truncation) + kwargs.setdefault("max_length", _truncation["max_length"]) + kwargs.setdefault("truncation_side", _truncation["direction"]) + kwargs.setdefault("stride", _truncation["stride"]) + kwargs.setdefault("truncation_strategy", _truncation["strategy"]) + else: + self._tokenizer.no_truncation() + + _padding = self._tokenizer.padding + if _padding is not None: + self._tokenizer.enable_padding(**_padding) + kwargs.setdefault("pad_token", _padding["pad_token"]) + kwargs.setdefault("pad_token_type_id", _padding["pad_type_id"]) + kwargs.setdefault("padding_side", _padding["direction"]) + kwargs.setdefault("max_length", _padding["length"]) + kwargs.setdefault("pad_to_multiple_of", _padding["pad_to_multiple_of"]) + + # We call this after having initialized the backend tokenizer because we update it. + super().__init__(**kwargs) + self._tokenizer.encode_special_tokens = self.split_special_tokens + + added_tokens_decoder_hash = {hash(repr(token)) for token in self.added_tokens_decoder} + tokens_to_add = [ + token + for index, token in sorted(added_tokens_decoder.items(), key=lambda x: x[0]) + if hash(repr(token)) not in added_tokens_decoder_hash + ] + encoder = list(self.added_tokens_encoder.keys()) + [str(token) for token in tokens_to_add] + # if some of the special tokens are strings, we check if we don't already have a token + tokens_to_add += [ + token for token in self.all_special_tokens_extended if token not in encoder and token not in tokens_to_add + ] + + if len(tokens_to_add) > 0: + tokens = [] + special_tokens = self.all_special_tokens + for token in tokens_to_add: + is_special = ( + (token.special or str(token) in special_tokens) + if isinstance(token, AddedToken) + else str(token) in special_tokens + ) + if isinstance(token, str): + token = AddedToken(token, special=is_special) + else: + token.special = is_special + tokens.append(token) + if tokens: + self.add_tokens(tokens) + + try: + pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__()) + if pre_tok_state.get("add_prefix_space", self.add_prefix_space) != self.add_prefix_space: + pre_tok_class = getattr(pre_tokenizers_fast, pre_tok_state.pop("type")) + pre_tok_state["add_prefix_space"] = self.add_prefix_space + self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state) + except Exception: + # We'll get an error if there is no pre_tokenizer, or if it's a custom pre_tokenizer that can + # not be serialized. In those cases, we just ignore the error as there's no pre_tokenizer + # for which we need to update the `add_prefix_space` attribute. + pass + + @property + def is_fast(self) -> bool: + return True + + @property + def can_save_slow_tokenizer(self) -> bool: + """ + `bool`: Whether or not the slow tokenizer can be saved. For a sentencepiece based slow tokenizer, this + can only be `True` if the original `"sentencepiece.model"` was not deleted. + """ + if "vocab_file" in self.vocab_files_names and self.vocab_files_names["vocab_file"].endswith(".model"): + if hasattr(self, "vocab_file") and self.vocab_file: + # If the vocab file is a sentencepiece model, we can save it + return os.path.isfile(self.vocab_file) + return False + else: + return True + + @property + def vocab_size(self) -> int: + """ + `int`: Size of the base vocabulary (without the added tokens). + """ + return self._tokenizer.get_vocab_size(with_added_tokens=False) + + def get_vocab(self) -> dict[str, int]: + return self._tokenizer.get_vocab(with_added_tokens=True) + + @property + def vocab(self) -> dict[str, int]: + return self.get_vocab() + + @property + def added_tokens_encoder(self) -> dict[str, int]: + """ + Returns the sorted mapping from string to index. The added tokens encoder is cached for performance + optimisation in `self._added_tokens_encoder` for the slow tokenizers. + """ + return {k.content: v for v, k in sorted(self.added_tokens_decoder.items(), key=lambda item: item[0])} + + @property + def added_tokens_decoder(self) -> dict[int, AddedToken]: + """ + Returns the added tokens in the vocabulary as a dictionary of index to AddedToken. + + Returns: + `dict[str, int]`: The added tokens. + """ + return self._tokenizer.get_added_tokens_decoder() + + def get_added_vocab(self) -> dict[str, int]: + """ + Returns the added tokens in the vocabulary as a dictionary of token to index. + + Returns: + `dict[str, int]`: The added tokens. + """ + return {k.content: v for v, k in sorted(self.added_tokens_decoder.items(), key=lambda item: item[0])} + + def __bool__(self) -> bool: + """ + Returns True, to avoid expensive `assert tokenizer` gotchas. + """ + return True + + def __len__(self) -> int: + """ + Size of the full vocabulary with the added tokens. + """ + return self._tokenizer.get_vocab_size(with_added_tokens=True) + + @property + def backend_tokenizer(self) -> TokenizerFast: + """ + `tokenizers.implementations.BaseTokenizer`: The Rust tokenizer used as a backend. + """ + return self._tokenizer + + @property + def decoder(self) -> DecoderFast: + """ + `tokenizers.decoders.Decoder`: The Rust decoder for this tokenizer. + """ + return self._tokenizer.decoder + + def _convert_encoding( + self, + encoding: EncodingFast, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + ) -> tuple[dict[str, Any], list[EncodingFast]]: + """ + Convert the encoding representation (from low-level HuggingFace tokenizer output) to a python Dict and a list + of encodings, take care of building a batch from overflowing tokens. + + Overflowing tokens are converted to additional examples (like batches) so the output values of the dict are + lists (overflows) of lists (tokens). + + Output shape: (overflows, sequence length) + """ + if return_token_type_ids is None: + return_token_type_ids = "token_type_ids" in self.model_input_names + if return_attention_mask is None: + return_attention_mask = "attention_mask" in self.model_input_names + + if return_overflowing_tokens and encoding.overflowing is not None: + encodings = [encoding] + encoding.overflowing + else: + encodings = [encoding] + + encoding_dict = defaultdict(list) + for e in encodings: + encoding_dict["input_ids"].append(e.ids) + + if return_token_type_ids: + encoding_dict["token_type_ids"].append(e.type_ids) + if return_attention_mask: + encoding_dict["attention_mask"].append(e.attention_mask) + if return_special_tokens_mask: + encoding_dict["special_tokens_mask"].append(e.special_tokens_mask) + if return_offsets_mapping: + encoding_dict["offset_mapping"].append(e.offsets) + if return_length: + encoding_dict["length"].append(len(e.ids)) + + return encoding_dict, encodings + + def convert_tokens_to_ids(self, tokens: Union[str, Iterable[str]]) -> Union[int, list[int]]: + """ + Converts a token string (or a sequence of tokens) in a single integer id (or a Iterable of ids), using the + vocabulary. + + Args: + tokens (`str` or `Iterable[str]`): One or several token(s) to convert to token id(s). + + Returns: + `int` or `list[int]`: The token id or list of token ids. + """ + if isinstance(tokens, str): + return self._convert_token_to_id_with_added_voc(tokens) + + return [self._convert_token_to_id_with_added_voc(token) for token in tokens] + + def _convert_token_to_id_with_added_voc(self, token: str) -> int: + index = self._tokenizer.token_to_id(token) + if index is None: + return self.unk_token_id + return index + + def _convert_id_to_token(self, index: int) -> Optional[str]: + return self._tokenizer.id_to_token(int(index)) + + def _add_tokens(self, new_tokens: list[Union[str, AddedToken]], special_tokens=False) -> int: + if special_tokens: + return self._tokenizer.add_special_tokens(new_tokens) + + return self._tokenizer.add_tokens(new_tokens) + + def num_special_tokens_to_add(self, pair: bool = False) -> int: + """ + Returns the number of added tokens when encoding a sequence with special tokens. + + + + This encodes a dummy input and checks the number of added tokens, and is therefore not efficient. Do not put + this inside your training loop. + + + + Args: + pair (`bool`, *optional*, defaults to `False`): + Whether the number of added tokens should be computed in the case of a sequence pair or a single + sequence. + + Returns: + `int`: Number of special tokens added to sequences. + """ + return self._tokenizer.num_special_tokens_to_add(pair) + + def convert_ids_to_tokens( + self, ids: Union[int, list[int]], skip_special_tokens: bool = False + ) -> Union[str, list[str]]: + """ + Converts a single index or a sequence of indices in a token or a sequence of tokens, using the vocabulary and + added tokens. + + Args: + ids (`int` or `list[int]`): + The token id (or token ids) to convert to tokens. + skip_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to remove special tokens in the decoding. + + Returns: + `str` or `list[str]`: The decoded token(s). + """ + if isinstance(ids, int): + return self._tokenizer.id_to_token(ids) + tokens = [] + # self.all_special_ids is an @property which may be slow, so only compute it once before the loop + ids_to_skip = set(self.all_special_ids) if skip_special_tokens else set() + for index in ids: + index = int(index) + if index in ids_to_skip: + continue + tokens.append(self._tokenizer.id_to_token(index)) + return tokens + + def tokenize(self, text: str, pair: Optional[str] = None, add_special_tokens: bool = False, **kwargs) -> list[str]: + return self.encode_plus(text=text, text_pair=pair, add_special_tokens=add_special_tokens, **kwargs).tokens() + + def set_truncation_and_padding( + self, + padding_strategy: PaddingStrategy, + truncation_strategy: TruncationStrategy, + max_length: int, + stride: int, + pad_to_multiple_of: Optional[int], + padding_side: Optional[str], + ): + """ + Define the truncation and the padding strategies for fast tokenizers (provided by HuggingFace tokenizers + library) and restore the tokenizer settings afterwards. + + The provided tokenizer has no padding / truncation strategy before the managed section. If your tokenizer set a + padding / truncation strategy before, then it will be reset to no padding / truncation when exiting the managed + section. + + Args: + padding_strategy ([`~utils.PaddingStrategy`]): + The kind of padding that will be applied to the input + truncation_strategy ([`~tokenization_utils_base.TruncationStrategy`]): + The kind of truncation that will be applied to the input + max_length (`int`): + The maximum size of a sequence. + stride (`int`): + The stride to use when handling overflow. + pad_to_multiple_of (`int`, *optional*): + If set will pad the sequence to a multiple of the provided value. This is especially useful to enable + the use of Tensor Cores on NVIDIA hardware with compute capability `>= 7.5` (Volta). + padding_side (`str`, *optional*): + The side on which the model should have padding applied. Should be selected between ['right', 'left']. + Default value is picked from the class attribute of the same name. + """ + _truncation = self._tokenizer.truncation + _padding = self._tokenizer.padding + # Set truncation and padding on the backend tokenizer + if truncation_strategy == TruncationStrategy.DO_NOT_TRUNCATE: + if _truncation is not None: + self._tokenizer.no_truncation() + else: + target = { + "max_length": max_length, + "stride": stride, + "strategy": truncation_strategy.value, + "direction": self.truncation_side, + } + + # _truncation might contain more keys that the target `transformers` + # supports. Use only the target keys to trigger `enable_truncation`. + # This should enable this code to works on various `tokenizers` + # targets. + if _truncation is None: + current = None + else: + current = {k: _truncation.get(k, None) for k in target} + + if current != target: + self._tokenizer.enable_truncation(**target) + + if padding_strategy == PaddingStrategy.DO_NOT_PAD: + if _padding is not None: + self._tokenizer.no_padding() + else: + length = max_length if padding_strategy == PaddingStrategy.MAX_LENGTH else None + target = { + "length": length, + "direction": padding_side if padding_side is not None else self.padding_side, + "pad_id": self.pad_token_id, + "pad_token": self.pad_token, + "pad_type_id": self.pad_token_type_id, + "pad_to_multiple_of": pad_to_multiple_of, + } + if _padding != target: + self._tokenizer.enable_padding(**target) + + def _batch_encode_plus( + self, + batch_text_or_text_pairs: Union[ + list[TextInput], list[TextInputPair], list[PreTokenizedInput], list[PreTokenizedInputPair] + ], + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + is_split_into_words: bool = False, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[str] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + split_special_tokens: bool = False, + ) -> BatchEncoding: + if not isinstance(batch_text_or_text_pairs, (tuple, list)): + raise TypeError( + f"batch_text_or_text_pairs has to be a list or a tuple (got {type(batch_text_or_text_pairs)})" + ) + + # Set the truncation and padding strategy and restore the initial configuration + self.set_truncation_and_padding( + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + ) + + if self._tokenizer.encode_special_tokens != split_special_tokens: + self._tokenizer.encode_special_tokens = split_special_tokens + + encodings = self._tokenizer.encode_batch( + batch_text_or_text_pairs, + add_special_tokens=add_special_tokens, + is_pretokenized=is_split_into_words, + ) + + # Convert encoding to dict + # `Tokens` has type: tuple[ + # list[dict[str, list[list[int]]]] or list[dict[str, 2D-Tensor]], + # list[EncodingFast] + # ] + # with nested dimensions corresponding to batch, overflows, sequence length + tokens_and_encodings = [ + self._convert_encoding( + encoding=encoding, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + ) + for encoding in encodings + ] + + # Convert the output to have dict[list] from list[dict] and remove the additional overflows dimension + # From (variable) shape (batch, overflows, sequence length) to ~ (batch * overflows, sequence length) + # (we say ~ because the number of overflow varies with the example in the batch) + # + # To match each overflowing sample with the original sample in the batch + # we add an overflow_to_sample_mapping array (see below) + sanitized_tokens = {} + for key in tokens_and_encodings[0][0]: + stack = [e for item, _ in tokens_and_encodings for e in item[key]] + sanitized_tokens[key] = stack + sanitized_encodings = [e for _, item in tokens_and_encodings for e in item] + + # If returning overflowing tokens, we need to return a mapping + # from the batch idx to the original sample + if return_overflowing_tokens: + overflow_to_sample_mapping = [] + for i, (toks, _) in enumerate(tokens_and_encodings): + overflow_to_sample_mapping += [i] * len(toks["input_ids"]) + sanitized_tokens["overflow_to_sample_mapping"] = overflow_to_sample_mapping + + for input_ids in sanitized_tokens["input_ids"]: + self._eventual_warn_about_too_long_sequence(input_ids, max_length, verbose) + return BatchEncoding(sanitized_tokens, sanitized_encodings, tensor_type=return_tensors) + + def _encode_plus( + self, + text: Union[TextInput, PreTokenizedInput], + text_pair: Optional[Union[TextInput, PreTokenizedInput]] = None, + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + is_split_into_words: bool = False, + pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[str] = None, + return_tensors: Optional[bool] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + split_special_tokens: bool = False, + **kwargs, + ) -> BatchEncoding: + batched_input = [(text, text_pair)] if text_pair else [text] + batched_output = self._batch_encode_plus( + batched_input, + is_split_into_words=is_split_into_words, + add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + split_special_tokens=split_special_tokens, + **kwargs, + ) + + # Return tensor is None, then we can remove the leading batch axis + # Overflowing tokens are returned as a batch of output so we keep them in this case + if return_tensors is None and not return_overflowing_tokens: + batched_output = BatchEncoding( + { + key: (value[0] if len(value) > 0 and isinstance(value[0], list) else value) + for key, value in batched_output.items() + }, + batched_output.encodings, + ) + + self._eventual_warn_about_too_long_sequence(batched_output["input_ids"], max_length, verbose) + + return batched_output + + def convert_tokens_to_string(self, tokens: list[str]) -> str: + return ( + self.backend_tokenizer.decoder.decode(tokens) + if self.backend_tokenizer.decoder is not None + else " ".join(tokens) + ) + + def _decode( + self, + token_ids: Union[int, list[int]], + skip_special_tokens: bool = False, + clean_up_tokenization_spaces: Optional[bool] = None, + **kwargs, + ) -> str: + self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False) + + if isinstance(token_ids, int): + token_ids = [token_ids] + text = self._tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens) + + clean_up_tokenization_spaces = ( + clean_up_tokenization_spaces + if clean_up_tokenization_spaces is not None + else self.clean_up_tokenization_spaces + ) + if clean_up_tokenization_spaces: + clean_text = self.clean_up_tokenization(text) + return clean_text + else: + return text + + def _save_pretrained( + self, + save_directory: Union[str, os.PathLike], + file_names: tuple[str, ...], + legacy_format: Optional[bool] = None, + filename_prefix: Optional[str] = None, + ) -> tuple[str, ...]: + """ + Save a tokenizer using the slow-tokenizer/legacy format: vocabulary + added tokens as well as in a unique JSON + file containing {config + vocab + added-tokens}. + """ + save_directory = str(save_directory) + + if self.slow_tokenizer_class is None and legacy_format is True: + raise ValueError( + "Your tokenizer does not have a legacy version defined and therefore cannot register this version. You" + " might consider leaving the legacy_format at `None` or setting it to `False`." + ) + + save_slow = ( + (legacy_format is None or legacy_format is True) + and self.slow_tokenizer_class is not None + and self.can_save_slow_tokenizer + ) + save_fast = legacy_format is None or legacy_format is False + + if save_slow: + added_tokens_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + ADDED_TOKENS_FILE + ) + # make sure to be forward compatible + added_vocab = {tok: index for tok, index in self.added_tokens_encoder.items() if index >= self.vocab_size} + if added_vocab: + with open(added_tokens_file, "w", encoding="utf-8") as f: + out_str = json.dumps(added_vocab, indent=2, sort_keys=True, ensure_ascii=False) + "\n" + f.write(out_str) + + vocab_files = self.save_vocabulary(save_directory, filename_prefix=filename_prefix) + file_names = file_names + vocab_files + (added_tokens_file,) + + if save_fast: + tokenizer_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + TOKENIZER_FILE + ) + self.backend_tokenizer.save(tokenizer_file) + file_names = file_names + (tokenizer_file,) + + return file_names + + def train_new_from_iterator( + self, + text_iterator, + vocab_size, + length=None, + new_special_tokens=None, + special_tokens_map=None, + **kwargs, + ): + """ + Trains a tokenizer on a new corpus with the same defaults (in terms of special tokens or tokenization pipeline) + as the current one. + + Args: + text_iterator (generator of `list[str]`): + The training corpus. Should be a generator of batches of texts, for instance a list of lists of texts + if you have everything in memory. + vocab_size (`int`): + The size of the vocabulary you want for your tokenizer. + length (`int`, *optional*): + The total number of sequences in the iterator. This is used to provide meaningful progress tracking + new_special_tokens (list of `str` or `AddedToken`, *optional*): + A list of new special tokens to add to the tokenizer you are training. + special_tokens_map (`dict[str, str]`, *optional*): + If you want to rename some of the special tokens this tokenizer uses, pass along a mapping old special + token name to new special token name in this argument. + kwargs (`dict[str, Any]`, *optional*): + Additional keyword arguments passed along to the trainer from the 🤗 Tokenizers library. + + Returns: + [`PreTrainedTokenizerFast`]: A new tokenizer of the same type as the original one, trained on + `text_iterator`. + + """ + tokenizer_json = json.loads(self._tokenizer.to_str()) + # Remove added tokens for now (uses IDs of tokens) + added_tokens = tokenizer_json.pop("added_tokens") + # Remove post processor for now (uses IDs of tokens) + post_processor = tokenizer_json.pop("post_processor") + + unk_token = None + # Remove vocab + if tokenizer_json["model"]["type"] == "BPE": + tokenizer_json["model"]["vocab"] = {} + tokenizer_json["model"]["merges"] = [] + elif tokenizer_json["model"]["type"] == "Unigram": + if tokenizer_json["model"]["unk_id"] is not None: + unk_id = tokenizer_json["model"]["unk_id"] + unk_token = tokenizer_json["model"]["vocab"][unk_id][0] + if special_tokens_map is not None and unk_token in special_tokens_map: + unk_token = special_tokens_map[unk_token] + tokenizer_json["model"]["unk_id"] = 0 + tokenizer_json["model"]["vocab"] = [[unk_token, 0.0]] + elif tokenizer_json["model"]["type"] in ["WordLevel", "WordPiece"]: + tokenizer_json["model"]["vocab"] = {} + else: + raise ValueError( + f"This method does not support this type of tokenizer (found {tokenizer_json['model']['type']}) " + "only BPE, Unigram, WordLevel and WordPiece." + ) + + if ( + special_tokens_map is not None + and "unk_token" in tokenizer_json["model"] + and tokenizer_json["model"]["unk_token"] in special_tokens_map + ): + tokenizer_json["model"]["unk_token"] = special_tokens_map[tokenizer_json["model"]["unk_token"]] + + tokenizer = TokenizerFast.from_str(json.dumps(tokenizer_json)) + + # Get the special tokens from the current tokenizer if none are specified. + special_tokens = [] + for added_token in added_tokens: + special = added_token.pop("special", None) + _ = added_token.pop("id", None) + if tokenizer_json["model"]["type"] != "Unigram" and not special: + continue + if special_tokens_map is not None and added_token["content"] in special_tokens_map: + added_token["content"] = special_tokens_map[added_token["content"]] + special_tokens.append(AddedToken(**added_token)) + + if new_special_tokens is not None: + special_tokens.extend(new_special_tokens) + + # Trainer needs to know the end of word / continuing subword thingies in BPE + if ( + tokenizer_json["model"]["type"] == "BPE" + and "continuing_subword_prefix" not in kwargs + and tokenizer_json["model"]["continuing_subword_prefix"] is not None + ): + kwargs["continuing_subword_prefix"] = tokenizer_json["model"]["continuing_subword_prefix"] + if ( + tokenizer_json["model"]["type"] == "BPE" + and "end_of_word_suffix" not in kwargs + and tokenizer_json["model"]["end_of_word_suffix"] is not None + ): + kwargs["end_of_word_suffix"] = tokenizer_json["model"]["end_of_word_suffix"] + if tokenizer_json["model"]["type"] == "Unigram" and unk_token is not None: + kwargs["unk_token"] = unk_token + if tokenizer_json["pre_tokenizer"] is not None: + if ( + tokenizer_json["pre_tokenizer"]["type"] == "ByteLevel" + or tokenizer_json["pre_tokenizer"]["type"] == "Sequence" + and "pretokenizers" in tokenizer_json["pre_tokenizer"] + and any( + pretokenizer["type"] == "ByteLevel" + for pretokenizer in tokenizer_json["pre_tokenizer"]["pretokenizers"] + ) + ): + kwargs["initial_alphabet"] = pre_tokenizers_fast.ByteLevel.alphabet() + + trainer_class = MODEL_TO_TRAINER_MAPPING[tokenizer_json["model"]["type"]] + trainer = trainer_class(vocab_size=vocab_size, special_tokens=special_tokens, **kwargs) + tokenizer.train_from_iterator(text_iterator, length=length, trainer=trainer) + + if post_processor is not None: + trained_tokenizer_json = json.loads(tokenizer.to_str()) + # Almost done, we just have to adjust the token IDs in the post processor + if "special_tokens" in post_processor: + for key in post_processor["special_tokens"]: + tokens = post_processor["special_tokens"][key]["tokens"] + if special_tokens_map is not None: + tokens = [special_tokens_map.get(token, token) for token in tokens] + post_processor["special_tokens"][key]["tokens"] = tokens + for token in tokens: + token_id = tokenizer.token_to_id(token) + if token_id is None: + raise ValueError( + "Attempted to set a token in the post processor that does not exist in the mapping" + ) + + post_processor["special_tokens"][key]["ids"] = [tokenizer.token_to_id(token) for token in tokens] + + for special_token in ["cls", "sep"]: + if special_token in post_processor: + token, _ = post_processor[special_token] + if special_tokens_map is not None and token in special_tokens_map: + token = special_tokens_map[token] + token_id = tokenizer.token_to_id(token) + if token_id is None: + raise ValueError( + "Attempted to set a token in the post processor that does not exist in the mapping" + ) + post_processor[special_token] = [token, token_id] + + trained_tokenizer_json["post_processor"] = post_processor + tokenizer = TokenizerFast.from_str(json.dumps(trained_tokenizer_json)) + + kwargs = self.init_kwargs.copy() + # Map pad/cls/mask token at the Transformers level + special_tokens_list = SpecialTokensMixin.SPECIAL_TOKENS_ATTRIBUTES.copy() + special_tokens_list.remove("additional_special_tokens") + for token in special_tokens_list: + if getattr(self, token) is not None: + special_token = getattr(self, token) + if special_tokens_map is not None and special_token in special_tokens_map: + special_token = special_tokens_map[special_token] + + special_token_full = self._special_tokens_map.get(token, None) + if isinstance(special_token_full, AddedToken): + # Create an added token with the same parameters except the content + kwargs[token] = AddedToken( + special_token, + single_word=special_token_full.single_word, + lstrip=special_token_full.lstrip, + rstrip=special_token_full.rstrip, + normalized=special_token_full.normalized, + special=True, + ) + else: + kwargs[token] = special_token + + additional_special_tokens = self.additional_special_tokens + if new_special_tokens is not None: + additional_special_tokens.extend(new_special_tokens) + if len(additional_special_tokens) > 0: + kwargs["additional_special_tokens"] = additional_special_tokens + + return self.__class__(tokenizer_object=tokenizer, **kwargs) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/trainer.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..3bfb72b4eaf52dbf37b3bfdcab10531d484ddf13 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/trainer.py @@ -0,0 +1,5723 @@ +# Copyright 2020-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +The Trainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task. +""" + +import contextlib +import copy +import functools +import glob +import importlib.metadata +import inspect +import json +import math +import os +import random +import re +import shutil +import sys +import tempfile +import time +import warnings +from collections.abc import Iterator, Mapping +from functools import partial +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable, Optional, Union + + +# Integrations must be imported before ML frameworks: +# ruff: isort: off +from .integrations import ( + get_reporting_integration_callbacks, +) + +# ruff: isort: on + +import huggingface_hub.utils as hf_hub_utils +import numpy as np +import safetensors.torch +import torch +import torch.distributed as dist +from huggingface_hub import ModelCard, create_repo, upload_folder +from packaging import version +from torch import nn +from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler + +from . import __version__ +from .configuration_utils import PretrainedConfig +from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator +from .debug_utils import DebugOption, DebugUnderflowOverflow +from .feature_extraction_sequence_utils import SequenceFeatureExtractor +from .feature_extraction_utils import FeatureExtractionMixin +from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend +from .image_processing_utils import BaseImageProcessor +from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available +from .integrations.tpu import tpu_spmd_dataloader +from .modelcard import TrainingSummary +from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model +from .models.auto.modeling_auto import ( + MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, + MODEL_MAPPING_NAMES, +) +from .optimization import Adafactor, get_scheduler +from .processing_utils import ProcessorMixin +from .pytorch_utils import ( + is_torch_greater_or_equal_than_2_3, +) +from .tokenization_utils_base import PreTrainedTokenizerBase +from .trainer_callback import ( + CallbackHandler, + DefaultFlowCallback, + ExportableState, + PrinterCallback, + ProgressCallback, + TrainerCallback, + TrainerControl, + TrainerState, +) +from .trainer_pt_utils import ( + DistributedTensorGatherer, + EvalLoopContainer, + IterableDatasetShard, + LabelSmoother, + LayerWiseDummyOptimizer, + LengthGroupedSampler, + SequentialDistributedSampler, + distributed_broadcast_scalars, + distributed_concat, + find_batch_size, + get_model_param_count, + get_module_class_from_name, + get_parameter_names, + nested_concat, + nested_detach, + nested_numpify, + nested_xla_mesh_reduce, + reissue_pt_warnings, + remove_dummy_checkpoint, + set_rng_state_for_device, +) +from .trainer_utils import ( + PREFIX_CHECKPOINT_DIR, + BestRun, + EvalLoopOutput, + EvalPrediction, + HPSearchBackend, + HubStrategy, + PredictionOutput, + RemoveColumnsCollator, + SaveStrategy, + TrainerMemoryTracker, + TrainOutput, + check_target_module_exists, + default_compute_objective, + denumpify_detensorize, + enable_full_determinism, + find_executable_batch_size, + get_last_checkpoint, + has_length, + neftune_post_forward_hook, + number_of_arguments, + seed_worker, + set_seed, + speed_metrics, +) +from .training_args import OptimizerNames, ParallelMode, TrainingArguments +from .utils import ( + ADAPTER_CONFIG_NAME, + ADAPTER_SAFE_WEIGHTS_NAME, + ADAPTER_WEIGHTS_NAME, + CONFIG_NAME, + GENERATION_CONFIG_NAME, + SAFE_WEIGHTS_INDEX_NAME, + SAFE_WEIGHTS_NAME, + WEIGHTS_INDEX_NAME, + WEIGHTS_NAME, + XLA_FSDPV2_MIN_VERSION, + PushInProgress, + PushToHubMixin, + can_return_loss, + check_torch_load_is_safe, + find_labels, + is_accelerate_available, + is_apollo_torch_available, + is_bitsandbytes_available, + is_datasets_available, + is_galore_torch_available, + is_grokadamw_available, + is_in_notebook, + is_liger_kernel_available, + is_lomo_available, + is_peft_available, + is_sagemaker_dp_enabled, + is_sagemaker_mp_enabled, + is_schedulefree_available, + is_torch_hpu_available, + is_torch_mlu_available, + is_torch_mps_available, + is_torch_musa_available, + is_torch_neuroncore_available, + is_torch_npu_available, + is_torch_optimi_available, + is_torch_xla_available, + is_torch_xpu_available, + is_torchao_available, + logging, + strtobool, +) +from .utils.deprecation import deprecate_kwarg +from .utils.import_utils import requires +from .utils.quantization_config import QuantizationMethod + + +DEFAULT_CALLBACKS = [DefaultFlowCallback] +DEFAULT_PROGRESS_CALLBACK = ProgressCallback + +if is_in_notebook(): + from .utils.notebook import NotebookProgressCallback + + DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback + +if is_datasets_available(): + import datasets + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + import torch_xla.debug.metrics as met + import torch_xla.runtime as xr + from torch_xla import __version__ as XLA_VERSION + + IS_XLA_FSDPV2_POST_2_2 = version.parse(XLA_VERSION) >= version.parse(XLA_FSDPV2_MIN_VERSION) + if IS_XLA_FSDPV2_POST_2_2: + import torch_xla.distributed.spmd as xs +else: + IS_XLA_FSDPV2_POST_2_2 = False + + +if is_sagemaker_mp_enabled(): + import smdistributed.modelparallel.torch as smp + from smdistributed.modelparallel import __version__ as SMP_VERSION + + IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10") + + from .trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat +else: + IS_SAGEMAKER_MP_POST_1_10 = False + +if is_peft_available(): + from peft import PeftModel + +if is_accelerate_available(): + from accelerate import Accelerator, skip_first_batches + from accelerate import __version__ as accelerate_version + from accelerate.state import AcceleratorState + from accelerate.utils import ( + AutocastKwargs, + DistributedDataParallelKwargs, + DistributedType, + load_fsdp_model, + load_fsdp_optimizer, + save_fsdp_model, + save_fsdp_optimizer, + ) + + DATA_SAMPLERS = [RandomSampler] + if version.parse(accelerate_version) > version.parse("1.3.0"): + from accelerate.utils import TorchTensorParallelPlugin + from accelerate.data_loader import SeedableRandomSampler + + DATA_SAMPLERS += [SeedableRandomSampler] + + if is_deepspeed_available(): + from accelerate.utils import DeepSpeedSchedulerWrapper + +if is_accelerate_available("0.28.0"): + from accelerate.utils import DataLoaderConfiguration + + +def _is_peft_model(model): + if is_peft_available(): + classes_to_check = (PeftModel,) + # Here we also check if the model is an instance of `PeftMixedModel` introduced in peft>=0.7.0: https://github.com/huggingface/transformers/pull/28321 + if version.parse(importlib.metadata.version("peft")) >= version.parse("0.7.0"): + from peft import PeftMixedModel + + classes_to_check = (*classes_to_check, PeftMixedModel) + return isinstance(model, classes_to_check) + return False + + +def _get_fsdp_ckpt_kwargs(): + # TODO: @AjayP13, @younesbelkada replace this check with version check at the next `accelerate` release + if is_accelerate_available() and "adapter_only" in list(inspect.signature(save_fsdp_model).parameters): + return {"adapter_only": True} + else: + return {} + + +def safe_globals(): + # Starting from version 2.4 PyTorch introduces a check for the objects loaded + # with torch.load(weights_only=True). Starting from 2.6 weights_only=True becomes + # a default and requires allowlisting of objects being loaded. + # See: https://github.com/pytorch/pytorch/pull/137602 + # See: https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals + # See: https://github.com/huggingface/accelerate/pull/3036 + if version.parse(torch.__version__).release < version.parse("2.6").release: + return contextlib.nullcontext() + + np_core = np._core if version.parse(np.__version__) >= version.parse("2.0.0") else np.core + allowlist = [np_core.multiarray._reconstruct, np.ndarray, np.dtype] + # numpy >1.25 defines numpy.dtypes.UInt32DType, but below works for + # all versions of numpy + allowlist += [type(np.dtype(np.uint32))] + + return torch.serialization.safe_globals(allowlist) + + +if TYPE_CHECKING: + import optuna + +logger = logging.get_logger(__name__) + + +# Name of the files used for checkpointing +TRAINING_ARGS_NAME = "training_args.bin" +TRAINER_STATE_NAME = "trainer_state.json" +OPTIMIZER_NAME = "optimizer.pt" +SCALER_NAME = "scaler.pt" +OPTIMIZER_NAME_BIN = "optimizer.bin" +SCHEDULER_NAME = "scheduler.pt" +FSDP_MODEL_NAME = "pytorch_model_fsdp" + + +@requires( + backends=( + "torch", + "accelerate", + ) +) +class Trainer: + """ + Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers. + + Args: + model ([`PreTrainedModel`] or `torch.nn.Module`, *optional*): + The model to train, evaluate or use for predictions. If not provided, a `model_init` must be passed. + + + + [`Trainer`] is optimized to work with the [`PreTrainedModel`] provided by the library. You can still use + your own models defined as `torch.nn.Module` as long as they work the same way as the 🤗 Transformers + models. + + + + args ([`TrainingArguments`], *optional*): + The arguments to tweak for training. Will default to a basic instance of [`TrainingArguments`] with the + `output_dir` set to a directory named *tmp_trainer* in the current directory if not provided. + data_collator (`DataCollator`, *optional*): + The function to use to form a batch from a list of elements of `train_dataset` or `eval_dataset`. Will + default to [`default_data_collator`] if no `processing_class` is provided, an instance of + [`DataCollatorWithPadding`] otherwise if the processing_class is a feature extractor or tokenizer. + train_dataset (Union[`torch.utils.data.Dataset`, `torch.utils.data.IterableDataset`, `datasets.Dataset`], *optional*): + The dataset to use for training. If it is a [`~datasets.Dataset`], columns not accepted by the + `model.forward()` method are automatically removed. + + Note that if it's a `torch.utils.data.IterableDataset` with some randomization and you are training in a + distributed fashion, your iterable dataset should either use a internal attribute `generator` that is a + `torch.Generator` for the randomization that must be identical on all processes (and the Trainer will + manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally + sets the seed of the RNGs used. + eval_dataset (Union[`torch.utils.data.Dataset`, dict[str, `torch.utils.data.Dataset`, `datasets.Dataset`]), *optional*): + The dataset to use for evaluation. If it is a [`~datasets.Dataset`], columns not accepted by the + `model.forward()` method are automatically removed. If it is a dictionary, it will evaluate on each + dataset prepending the dictionary key to the metric name. + processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*): + Processing class used to process the data. If provided, will be used to automatically process the inputs + for the model, and it will be saved along the model to make it easier to rerun an interrupted training or + reuse the fine-tuned model. + This supersedes the `tokenizer` argument, which is now deprecated. + model_init (`Callable[[], PreTrainedModel]`, *optional*): + A function that instantiates the model to be used. If provided, each call to [`~Trainer.train`] will start + from a new instance of the model as given by this function. + + The function may have zero argument, or a single one containing the optuna/Ray Tune/SigOpt trial object, to + be able to choose different architectures according to hyper parameters (such as layer count, sizes of + inner layers, dropout probabilities etc). + compute_loss_func (`Callable`, *optional*): + A function that accepts the raw model outputs, labels, and the number of items in the entire accumulated + batch (batch_size * gradient_accumulation_steps) and returns the loss. For example, see the default [loss function](https://github.com/huggingface/transformers/blob/052e652d6d53c2b26ffde87e039b723949a53493/src/transformers/trainer.py#L3618) used by [`Trainer`]. + compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*): + The function that will be used to compute metrics at evaluation. Must take a [`EvalPrediction`] and return + a dictionary string to metric values. *Note* When passing TrainingArgs with `batch_eval_metrics` set to + `True`, your compute_metrics function must take a boolean `compute_result` argument. This will be triggered + after the last eval batch to signal that the function needs to calculate and return the global summary + statistics rather than accumulating the batch-level statistics + callbacks (List of [`TrainerCallback`], *optional*): + A list of callbacks to customize the training loop. Will add those to the list of default callbacks + detailed in [here](callback). + + If you want to remove one of the default callbacks used, use the [`Trainer.remove_callback`] method. + optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`): + A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your + model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`. + optimizer_cls_and_kwargs (`tuple[Type[torch.optim.Optimizer], dict[str, Any]]`, *optional*): + A tuple containing the optimizer class and keyword arguments to use. + Overrides `optim` and `optim_args` in `args`. Incompatible with the `optimizers` argument. + + Unlike `optimizers`, this argument avoids the need to place model parameters on the correct devices before initializing the Trainer. + preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*): + A function that preprocess the logits right before caching them at each evaluation step. Must take two + tensors, the logits and the labels, and return the logits once processed as desired. The modifications made + by this function will be reflected in the predictions received by `compute_metrics`. + + Note that the labels (second parameter) will be `None` if the dataset does not have them. + + Important attributes: + + - **model** -- Always points to the core model. If using a transformers model, it will be a [`PreTrainedModel`] + subclass. + - **model_wrapped** -- Always points to the most external model in case one or more other modules wrap the + original model. This is the model that should be used for the forward pass. For example, under `DeepSpeed`, + the inner model is wrapped in `DeepSpeed` and then again in `torch.nn.DistributedDataParallel`. If the inner + model hasn't been wrapped, then `self.model_wrapped` is the same as `self.model`. + - **is_model_parallel** -- Whether or not a model has been switched to a model parallel mode (different from + data parallelism, this means some of the model layers are split on different GPUs). + - **place_model_on_device** -- Whether or not to automatically place the model on the device - it will be set + to `False` if model parallel or deepspeed is used, or if the default + `TrainingArguments.place_model_on_device` is overridden to return `False` . + - **is_in_train** -- Whether or not a model is currently running `train` (e.g. when `evaluate` is called while + in `train`) + + """ + + # Those are used as methods of the Trainer in examples. + from .trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics, save_state + + @deprecate_kwarg("tokenizer", new_name="processing_class", version="5.0.0", raise_if_both_names=True) + def __init__( + self, + model: Union[PreTrainedModel, nn.Module, None] = None, + args: Optional[TrainingArguments] = None, + data_collator: Optional[DataCollator] = None, + train_dataset: Optional[Union[Dataset, IterableDataset, "datasets.Dataset"]] = None, + eval_dataset: Optional[Union[Dataset, dict[str, Dataset], "datasets.Dataset"]] = None, + processing_class: Optional[ + Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] + ] = None, + model_init: Optional[Callable[..., PreTrainedModel]] = None, + compute_loss_func: Optional[Callable] = None, + compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None, + callbacks: Optional[list[TrainerCallback]] = None, + optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None), + optimizer_cls_and_kwargs: Optional[tuple[type[torch.optim.Optimizer], dict[str, Any]]] = None, + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + ): + if args is None: + output_dir = "tmp_trainer" + logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.") + args = TrainingArguments(output_dir=output_dir) + if args.batch_eval_metrics and compute_metrics is not None: + if "compute_result" not in inspect.signature(compute_metrics).parameters: + raise ValueError( + "When using `batch_eval_metrics`, your `compute_metrics` function must take a `compute_result`" + " boolean argument which will be triggered after the last batch of the eval set to signal that the" + " summary statistics should be returned by the function." + ) + if args.eval_strategy is not None and args.eval_strategy != "no" and eval_dataset is None: + raise ValueError( + f"You have set `args.eval_strategy` to {args.eval_strategy} but you didn't pass an `eval_dataset` to `Trainer`. Either set `args.eval_strategy` to `no` or pass an `eval_dataset`. " + ) + if args.save_strategy == SaveStrategy.BEST or args.load_best_model_at_end: + if args.metric_for_best_model is None: + raise ValueError( + "`args.metric_for_best_model` must be provided when using 'best' save_strategy or if `args.load_best_model_at_end` is set to `True`." + ) + + self.args = args + self.compute_loss_func = compute_loss_func + # Seed must be set before instantiating the model when using model + enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed) + + self.hp_name = None + self.deepspeed = None + self.is_in_train = False + self.model = model + self.create_accelerator_and_postprocess() + + # memory metrics - must set up as early as possible + self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics) + self._memory_tracker.start() + + # set the correct log level depending on the node + log_level = args.get_process_log_level() + logging.set_verbosity(log_level) + + # force device and distributed setup init explicitly + args._setup_devices + + if model is None: + if model_init is not None: + self.model_init = model_init + model = self.call_model_init() + else: + raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument") + else: + if model_init is not None: + warnings.warn( + "`Trainer` requires either a `model` or `model_init` argument, but not both. `model_init` will" + " overwrite your model when calling the `train` method. This will become a fatal error in the next" + " release.", + FutureWarning, + ) + self.model_init = model_init + + if model.__class__.__name__ in MODEL_MAPPING_NAMES: + raise ValueError( + f"The model you have picked ({model.__class__.__name__}) cannot be used as is for training: it only " + "computes hidden states and does not accept any labels. You should choose a model with a head " + "suitable for your task like any of the `AutoModelForXxx` listed at " + "https://huggingface.co/docs/transformers/model_doc/auto" + ) + + if getattr(model, "is_parallelizable", False) and getattr(model, "model_parallel", False): + self.is_model_parallel = True + else: + self.is_model_parallel = False + + if getattr(model, "hf_device_map", None) is not None: + devices = [device for device in set(model.hf_device_map.values()) if device not in ["cpu", "disk"]] + if len(devices) > 1: + self.is_model_parallel = True + elif len(devices) == 1: + self.is_model_parallel = self.args.device != torch.device(devices[0]) + else: + self.is_model_parallel = False + + # warn users + if self.is_model_parallel: + logger.info( + "You have loaded a model on multiple GPUs. `is_model_parallel` attribute will be force-set" + " to `True` to avoid any unexpected behavior such as device placement mismatching." + ) + + if self.args.use_liger_kernel: + if is_liger_kernel_available(): + from liger_kernel.transformers import _apply_liger_kernel_to_instance + + # Prepare kernel config - use provided config or default (empty dict for default behavior) + kernel_config = self.args.liger_kernel_config if self.args.liger_kernel_config is not None else {} + + if isinstance(model, PreTrainedModel): + # Patch the model with liger kernels. Use the specified or default kernel configurations. + _apply_liger_kernel_to_instance(model=model, **kernel_config) + elif hasattr(model, "get_base_model") and isinstance(model.get_base_model(), PreTrainedModel): + # Patch the base model with liger kernels where model is a PeftModel. Use the specified or default kernel configurations. + _apply_liger_kernel_to_instance(model=model.get_base_model(), **kernel_config) + else: + logger.warning( + "The model is not an instance of PreTrainedModel. No liger kernels will be applied." + ) + else: + raise ImportError( + "You have set `use_liger_kernel` to `True` but liger-kernel >= 0.3.0 is not available. " + "Please install it with `pip install liger-kernel`" + ) + + _is_quantized_and_base_model = getattr(model, "is_quantized", False) and not getattr( + model, "_hf_peft_config_loaded", False + ) + _quantization_method_supports_training = ( + getattr(model, "hf_quantizer", None) is not None and model.hf_quantizer.is_trainable + ) + + _is_model_quantized_and_qat_trainable = getattr(model, "hf_quantizer", None) is not None and getattr( + model.hf_quantizer, "is_qat_trainable", False + ) + + # Filter out quantized + compiled models + if _is_quantized_and_base_model and hasattr(model, "_orig_mod"): + raise ValueError( + "You cannot fine-tune quantized model with `torch.compile()` make sure to pass a non-compiled model when fine-tuning a quantized model with PEFT" + ) + + # At this stage the model is already loaded + if _is_quantized_and_base_model and not _is_peft_model(model) and not _is_model_quantized_and_qat_trainable: + raise ValueError( + "You cannot perform fine-tuning on purely quantized models. Please attach trainable adapters on top of" + " the quantized model to correctly perform fine-tuning. Please see: https://huggingface.co/docs/transformers/peft" + " for more details" + ) + elif _is_quantized_and_base_model and not _quantization_method_supports_training: + raise ValueError( + f"The model you are trying to fine-tune is quantized with {model.hf_quantizer.quantization_config.quant_method}" + " but that quantization method do not support training. Please open an issue on GitHub: https://github.com/huggingface/transformers" + f" to request the support for training support for {model.hf_quantizer.quantization_config.quant_method}" + ) + + self.is_fsdp_xla_enabled = args.fsdp_config["xla"] + if len(args.fsdp) > 0: + if self.is_deepspeed_enabled: + raise ValueError( + "Using --fsdp xxx together with --deepspeed is not possible, deactivate one of those flags." + ) + if not args.fsdp_config["xla"] and args.parallel_mode != ParallelMode.DISTRIBUTED: + raise ValueError("Using fsdp only works in distributed training.") + + # one place to sort out whether to place the model on device or not + # postpone switching model to cuda when: + # 1. MP - since we are trying to fit a much bigger than 1 gpu model + # 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway, + # and we only use deepspeed for training at the moment + # 3. full bf16 or fp16 eval - since the model needs to be cast to the right dtype first + # 4. FSDP - same as MP + self.place_model_on_device = args.place_model_on_device + if ( + self.is_model_parallel + or self.is_deepspeed_enabled + or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train) + or self.is_fsdp_xla_enabled + or self.is_fsdp_enabled + ): + self.place_model_on_device = False + + default_collator = ( + DataCollatorWithPadding(processing_class) + if processing_class is not None + and isinstance(processing_class, (PreTrainedTokenizerBase, SequenceFeatureExtractor)) + else default_data_collator + ) + self.data_collator = data_collator if data_collator is not None else default_collator + self.train_dataset = train_dataset + self.eval_dataset = eval_dataset + self.processing_class = processing_class + + # Bnb Quantized models doesn't support `.to` operation. + if ( + self.place_model_on_device + and getattr(model, "quantization_method", None) != QuantizationMethod.BITS_AND_BYTES + ): + self._move_model_to_device(model, args.device) + + # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs + if self.is_model_parallel: + self.args._n_gpu = 1 + + # later use `self.model is self.model_wrapped` to check if it's wrapped or not + self.model_wrapped = model + self.model = model + + # Just in case the model was wrapped outside of the `Trainer` + unwrapped_model = self.accelerator.unwrap_model(model) + # We also unwrap peft model + if _is_peft_model(unwrapped_model): + if hasattr(unwrapped_model, "get_base_model"): + unwrapped_model = unwrapped_model.get_base_model() + elif hasattr(unwrapped_model, "base_model") and hasattr(unwrapped_model.base_model, "model"): + unwrapped_model = unwrapped_model.base_model.model + else: + raise AttributeError("Cannot extract base model safely from this PEFT wrapper.") + + # Check if the model has explicit setup for loss kwargs, + # if not, check if `**kwargs` are in model.forward + if hasattr(unwrapped_model, "accepts_loss_kwargs"): + self.model_accepts_loss_kwargs = unwrapped_model.accepts_loss_kwargs + else: + forward_params = inspect.signature(unwrapped_model.forward).parameters + self.model_accepts_loss_kwargs = any( + k.kind == inspect.Parameter.VAR_KEYWORD for k in forward_params.values() + ) + + self.neftune_noise_alpha = args.neftune_noise_alpha + + self.compute_metrics = compute_metrics + self.preprocess_logits_for_metrics = preprocess_logits_for_metrics + self.optimizer, self.lr_scheduler = optimizers + self.optimizer_cls_and_kwargs = optimizer_cls_and_kwargs + if self.optimizer_cls_and_kwargs is not None and self.optimizer is not None: + raise RuntimeError("Passing both `optimizers` and `optimizer_cls_and_kwargs` arguments is incompatible.") + if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None): + raise RuntimeError( + "Passing a `model_init` is incompatible with providing the `optimizers` argument. " + "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method." + ) + if is_torch_xla_available() and self.optimizer is not None: + for param in self.model.parameters(): + model_device = param.device + break + for param_group in self.optimizer.param_groups: + if len(param_group["params"]) > 0: + optimizer_device = param_group["params"][0].device + break + if model_device != optimizer_device: + raise ValueError( + "The model and the optimizer parameters are not on the same device, which probably means you" + " created an optimizer around your model **before** putting on the device and passing it to the" + " `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and" + " `model.to(xm.xla_device())` is performed before the optimizer creation in your script." + ) + if (self.is_fsdp_xla_enabled or self.is_fsdp_enabled) and ( + self.optimizer is not None or self.lr_scheduler is not None + ): + raise RuntimeError( + "Passing `optimizers` is not allowed if PyTorch FSDP is enabled. " + "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method." + ) + default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to) + callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks + self.callback_handler = CallbackHandler( + callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler + ) + self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK) + + # Will be set to True by `self._setup_loggers()` on first call to `self.log()`. + self._loggers_initialized = False + + # Create distant repo and output directory if needed + self.hub_model_id = None + if self.args.push_to_hub: + self.init_hf_repo() + if self.args.should_save: + os.makedirs(self.args.output_dir, exist_ok=True) + + if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)): + raise TypeError("The `data_collator` should be a simple callable (function, class with `__call__`).") + + if args.max_steps > 0 and args.num_train_epochs > 0: + logger.info("max_steps is given, it will override any value given in num_train_epochs") + + if train_dataset is not None and not has_length(train_dataset) and args.max_steps <= 0: + raise ValueError( + "The train_dataset does not implement __len__, max_steps has to be specified. " + "The number of steps needs to be known in advance for the learning rate scheduler." + ) + + if ( + train_dataset is not None + and isinstance(train_dataset, torch.utils.data.IterableDataset) + and args.group_by_length + ): + raise ValueError("the `--group_by_length` option is only available for `Dataset`, not `IterableDataset") + + self._signature_columns = None + + # Mixed precision setup + self.use_apex = False + self.use_cpu_amp = False + + # Mixed precision setup for SageMaker Model Parallel + if is_sagemaker_mp_enabled(): + # BF16 + model parallelism in SageMaker: currently not supported, raise an error + if args.bf16: + raise ValueError("SageMaker Model Parallelism does not support BF16 yet. Please use FP16 instead ") + + if IS_SAGEMAKER_MP_POST_1_10: + # When there's mismatch between SMP config and trainer argument, use SMP config as truth + if args.fp16 != smp.state.cfg.fp16: + logger.warning( + f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, " + f"but FP16 provided in trainer argument is {args.fp16}, " + f"setting to {smp.state.cfg.fp16}" + ) + args.fp16 = smp.state.cfg.fp16 + else: + # smp < 1.10 does not support fp16 in trainer. + if hasattr(smp.state.cfg, "fp16"): + logger.warning( + f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, " + "but SageMaker Model Parallelism < 1.10 does not support FP16 in trainer." + ) + if (args.fp16 or args.bf16) and args.half_precision_backend == "auto": + if args.device == torch.device("cpu"): + if args.fp16: + if not is_torch_greater_or_equal_than_2_3: + raise ValueError("Tried to use `fp16` but it is not supported on cpu") + else: + args.half_precision_backend = "cpu_amp" + logger.info(f"Using {args.half_precision_backend} half precision backend") + + if (args.fp16 or args.bf16) and not (self.is_deepspeed_enabled or is_sagemaker_mp_enabled()): + # deepspeed and SageMaker Model Parallel manage their own half precision + if args.half_precision_backend == "cpu_amp": + self.use_cpu_amp = True + self.amp_dtype = torch.bfloat16 + elif args.half_precision_backend == "apex": + self.use_apex = True + + # Label smoothing + if self.args.label_smoothing_factor != 0: + self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor) + else: + self.label_smoother = None + + # Check for multi-label classification incompatibility + if self.args.label_smoothing_factor > 0: + if getattr(self.model.config, "problem_type", None) == "multi_label_classification": + warnings.warn( + "Label smoothing is not compatible with multi-label classification. " + "Disabling label smoothing for this training run.", + UserWarning, + ) + self.label_smoother = None + + self.control = TrainerControl() + + self.state = TrainerState( + is_local_process_zero=self.is_local_process_zero(), + is_world_process_zero=self.is_world_process_zero(), + stateful_callbacks=[ + cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState) + ], + ) + # Internal variable to count flos in each process, will be accumulated in `self.state.total_flos` then + # returned to 0 every time flos need to be logged + self.current_flos = 0 + self.hp_search_backend = None + + model_to_inspect = self.model + if _is_peft_model(self.model): + if hasattr(self.model, "get_base_model"): + model_to_inspect = self.model.get_base_model() + else: + # PeftMixedModel do not provide a `get_base_model` method + model_to_inspect = self.model.base_model.model + default_label_names = find_labels(model_to_inspect.__class__) + self.label_names = default_label_names if self.args.label_names is None else self.args.label_names + self.can_return_loss = can_return_loss(model_to_inspect.__class__) + self.control = self.callback_handler.on_init_end(self.args, self.state, self.control) + + # Internal variables to help with automatic batch size reduction + self._train_batch_size = args.train_batch_size + self._created_lr_scheduler = False + + # very last + self._memory_tracker.stop_and_update_metrics() + + self.is_fsdp_xla_v2_enabled = args.fsdp_config.get("xla_fsdp_v2", False) + if self.is_fsdp_xla_v2_enabled: + if not IS_XLA_FSDPV2_POST_2_2: + raise ValueError("FSDPv2 requires `torch_xla` 2.2 or higher.") + # Prepare the SPMD mesh that is going to be used by the data loader and the FSDPv2 wrapper. + # Tensor axis is just a placeholder where it will not be used in FSDPv2. + num_devices = xr.global_runtime_device_count() + xs.set_global_mesh(xs.Mesh(np.array(range(num_devices)), (num_devices, 1), axis_names=("fsdp", "tensor"))) + self.is_fsdp_xla_v1_enabled = self.is_fsdp_xla_enabled and not self.is_fsdp_xla_v2_enabled + + @property + def tokenizer(self) -> Optional[PreTrainedTokenizerBase]: + logger.warning("Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.") + return self.processing_class + + @tokenizer.setter + def tokenizer(self, processing_class) -> None: + logger.warning( + "Trainer.tokenizer is now deprecated. You should use `Trainer.processing_class = processing_class` instead." + ) + self.processing_class = processing_class + + def _activate_neftune(self, model): + r""" + Activates the neftune as presented in this code: https://github.com/neelsjain/NEFTune and paper: + https://huggingface.co/papers/2310.05914 + """ + unwrapped_model = self.accelerator.unwrap_model(model) + + if _is_peft_model(unwrapped_model): + embeddings = unwrapped_model.base_model.model.get_input_embeddings() + else: + embeddings = unwrapped_model.get_input_embeddings() + + del unwrapped_model + + embeddings.neftune_noise_alpha = self.neftune_noise_alpha + hook_handle = embeddings.register_forward_hook(neftune_post_forward_hook) + self.neftune_hook_handle = hook_handle + return model + + def _deactivate_neftune(self, model): + """ + Deactivates the neftune method. Make sure to call `_activate_neftune` first. + """ + if not hasattr(self, "neftune_hook_handle"): + raise ValueError("Neftune is not activated make sure to call `trainer._activate_neftune()` first") + + unwrapped_model = self.accelerator.unwrap_model(model) + + if _is_peft_model(unwrapped_model): + embeddings = unwrapped_model.base_model.model.get_input_embeddings() + else: + embeddings = unwrapped_model.get_input_embeddings() + + self.neftune_hook_handle.remove() + del embeddings.neftune_noise_alpha, unwrapped_model + + def add_callback(self, callback): + """ + Add a callback to the current list of [`~transformers.TrainerCallback`]. + + Args: + callback (`type` or [`~transformers.TrainerCallback]`): + A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the + first case, will instantiate a member of that class. + """ + self.callback_handler.add_callback(callback) + + def pop_callback(self, callback): + """ + Remove a callback from the current list of [`~transformers.TrainerCallback`] and returns it. + + If the callback is not found, returns `None` (and no error is raised). + + Args: + callback (`type` or [`~transformers.TrainerCallback]`): + A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the + first case, will pop the first member of that class found in the list of callbacks. + + Returns: + [`~transformers.TrainerCallback`]: The callback removed, if found. + """ + return self.callback_handler.pop_callback(callback) + + def remove_callback(self, callback): + """ + Remove a callback from the current list of [`~transformers.TrainerCallback`]. + + Args: + callback (`type` or [`~transformers.TrainerCallback]`): + A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the + first case, will remove the first member of that class found in the list of callbacks. + """ + self.callback_handler.remove_callback(callback) + + def _move_model_to_device(self, model, device): + if getattr(model, "hf_device_map", None) is not None: + logger.warning( + "The model is already on multiple devices. Skipping the move to device specified in `args`." + ) + return + model = model.to(device) + # Moving a model to an XLA device disconnects the tied weights, so we have to retie them. + if self.args.parallel_mode == ParallelMode.TPU and hasattr(model, "tie_weights"): + model.tie_weights() + + def _align_special_tokens(self): + """ + Aligns the special tokens of the tokenizer with the model configs. + + A new tokens may be defined in the tokenizer for fine-tuning purposes, e.g. an "end of turn" token may be + added on chat models. In that case, we want the model configs to be aligned with the tokenizer, so that all + downstream uses work as expected. This alignment should happen before training, to ensure the prediction step + uses the new tokens as well. + """ + if isinstance(self.processing_class, ProcessorMixin): + tokenizer: PreTrainedTokenizerBase = self.processing_class.tokenizer + else: + tokenizer = self.processing_class + model_has_generation_config = ( + hasattr(self.model, "generation_config") and self.model.generation_config is not None + ) + updated_tokens = {} + + # 1 - Align EOS token. EOS is more complex than the others, as `generation_config` may hold more than one EOS + # token. + tokenizer_has_new_eos = tokenizer.eos_token_id != self.model.config.eos_token_id + if model_has_generation_config: + # `generation_config.eos_token_id` is None: direct comparison + if self.model.generation_config.eos_token_id is None: + tokenizer_has_new_eos |= tokenizer.eos_token_id != self.model.generation_config.eos_token_id + else: + # `generation_config.eos_token_id` is an `int`: convert it to list (and continue below) + if isinstance(self.model.generation_config.eos_token_id, int): + self.model.generation_config.eos_token_id = [self.model.generation_config.eos_token_id] + # `generation_config.eos_token_id` is a `list`: check if the tokenizer's EOS token is in the list + tokenizer_has_new_eos |= tokenizer.eos_token_id not in self.model.generation_config.eos_token_id + + if tokenizer_has_new_eos: + updated_tokens["eos_token_id"] = tokenizer.eos_token_id + self.model.config.eos_token_id = tokenizer.eos_token_id + # The generation config may hold more than one EOS token. We preserve the original EOS tokens: any of the + # EOS tokens defined here will halt generation. + if model_has_generation_config: + all_eos_tokens = [tokenizer.eos_token_id] + if self.model.generation_config.eos_token_id is not None: + all_eos_tokens += list(self.model.generation_config.eos_token_id) + self.model.generation_config.eos_token_id = [token for token in all_eos_tokens if token is not None] + + # 2 - Align BOS + tokenizer_has_new_bos = tokenizer.bos_token_id != self.model.config.bos_token_id + if model_has_generation_config: + tokenizer_has_new_bos |= tokenizer.bos_token_id != self.model.generation_config.bos_token_id + + if tokenizer_has_new_bos: + updated_tokens["bos_token_id"] = tokenizer.bos_token_id + self.model.config.bos_token_id = tokenizer.bos_token_id + if model_has_generation_config: + self.model.generation_config.bos_token_id = tokenizer.bos_token_id + + # 3 - Align PAD + tokenizer_has_new_pad = tokenizer.pad_token_id != self.model.config.pad_token_id + if model_has_generation_config: + tokenizer_has_new_pad |= tokenizer.pad_token_id != self.model.generation_config.pad_token_id + + if tokenizer_has_new_pad: + updated_tokens["pad_token_id"] = tokenizer.pad_token_id + self.model.config.pad_token_id = tokenizer.pad_token_id + if model_has_generation_config: + self.model.generation_config.pad_token_id = tokenizer.pad_token_id + + # 4 - Warn users about the changes + if len(updated_tokens) > 0: + logger.warning( + "The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. " + "The model config and generation config were aligned accordingly, being updated with the tokenizer's " + f"values. Updated tokens: {updated_tokens}." + ) + + def _set_signature_columns_if_needed(self): + if self._signature_columns is None: + # Inspect model forward signature to keep only the arguments it accepts. + model_to_inspect = self.model + if _is_peft_model(self.model): + if hasattr(self.model, "get_base_model"): + model_to_inspect = self.model.get_base_model() + else: + # PeftMixedModel do not provide a `get_base_model` method + model_to_inspect = self.model.base_model.model + signature = inspect.signature(model_to_inspect.forward) + self._signature_columns = list(signature.parameters.keys()) + # Labels may be named label or label_ids, the default data collator handles that. + self._signature_columns += list(set(["label", "label_ids"] + self.label_names)) + + def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None): + if not self.args.remove_unused_columns: + return dataset + self._set_signature_columns_if_needed() + signature_columns = self._signature_columns + + ignored_columns = list(set(dataset.column_names) - set(signature_columns)) + if len(ignored_columns) > 0: + dset_description = "" if description is None else f"in the {description} set" + logger.info( + f"The following columns {dset_description} don't have a corresponding argument in " + f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}." + f" If {', '.join(ignored_columns)} are not expected by `{self.model.__class__.__name__}.forward`, " + " you can safely ignore this message." + ) + + columns = [k for k in signature_columns if k in dataset.column_names] + if len(columns) == 0: + raise ValueError( + f"No columns in the dataset match the model's forward method signature: ({', '.join(signature_columns)}). " + f"The following columns have been ignored: [{', '.join(ignored_columns)}]. " + "Please check the dataset and model. You may need to set `remove_unused_columns=False` in `TrainingArguments`." + ) + + if version.parse(datasets.__version__) < version.parse("1.4.0"): + dataset.set_format( + type=dataset.format["type"], columns=columns, format_kwargs=dataset.format["format_kwargs"] + ) + return dataset + else: + return dataset.remove_columns(ignored_columns) + + def _get_collator_with_removed_columns( + self, data_collator: Callable, description: Optional[str] = None + ) -> Callable: + """Wrap the data collator in a callable removing unused columns.""" + if not self.args.remove_unused_columns: + return data_collator + self._set_signature_columns_if_needed() + signature_columns = self._signature_columns + + remove_columns_collator = RemoveColumnsCollator( + data_collator=data_collator, + signature_columns=signature_columns, + logger=logger, + description=description, + model_name=self.model.__class__.__name__, + ) + return remove_columns_collator + + def _get_train_sampler(self, train_dataset: Optional[Dataset] = None) -> Optional[torch.utils.data.Sampler]: + if train_dataset is None: + train_dataset = self.train_dataset + if train_dataset is None or not has_length(train_dataset): + return None + + # Build the sampler. + if self.args.group_by_length: + if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): + lengths = ( + train_dataset[self.args.length_column_name] + if self.args.length_column_name in train_dataset.column_names + else None + ) + else: + lengths = None + model_input_name = ( + self.processing_class.model_input_names[0] if self.processing_class is not None else None + ) + return LengthGroupedSampler( + self.args.train_batch_size * self.args.gradient_accumulation_steps, + dataset=train_dataset, + lengths=lengths, + model_input_name=model_input_name, + ) + + else: + return RandomSampler(train_dataset) + + def _get_dataloader( + self, + dataset: Dataset, + description: str, + batch_size: int, + sampler_fn: Optional[Callable[[Dataset], torch.utils.data.Sampler]] = None, + is_training: bool = False, + dataloader_key: Optional[str] = None, + ) -> DataLoader: + """Create a [`~torch.utils.data.DataLoader`] from the given dataset.""" + + data_collator = self.data_collator + if is_datasets_available() and isinstance(dataset, datasets.Dataset): + dataset = self._remove_unused_columns(dataset, description=description) + else: + data_collator = self._get_collator_with_removed_columns(self.data_collator, description=description) + + dataloader_params = { + "batch_size": batch_size, + "collate_fn": data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + "persistent_workers": self.args.dataloader_persistent_workers, + } + + if not isinstance(dataset, torch.utils.data.IterableDataset): + if sampler_fn is not None: + dataloader_params["sampler"] = sampler_fn(dataset) + dataloader_params["drop_last"] = self.args.dataloader_drop_last + dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor + if is_training: + dataloader_params["worker_init_fn"] = partial( + seed_worker, num_workers=self.args.dataloader_num_workers, rank=self.args.process_index + ) + + dataloader = self.accelerator.prepare(DataLoader(dataset, **dataloader_params)) + + # Store the prepared dataloader for subsequent evaluations if using persistent workers. + if dataloader_key is not None and self.args.dataloader_persistent_workers: + if hasattr(self, "_eval_dataloaders"): + self._eval_dataloaders[dataloader_key] = dataloader + else: + self._eval_dataloaders = {dataloader_key: dataloader} + + return dataloader + + def get_train_dataloader(self) -> DataLoader: + """ + Returns the training [`~torch.utils.data.DataLoader`]. + + Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed + training if necessary) otherwise. + + Subclass and override this method if you want to inject some custom behavior. + """ + if self.train_dataset is None: + raise ValueError("Trainer: training requires a train_dataset.") + + return self._get_dataloader( + dataset=self.train_dataset, + description="Training", + batch_size=self._train_batch_size, + sampler_fn=self._get_train_sampler, + is_training=True, + ) + + def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]: + if eval_dataset is None or not has_length(eval_dataset): + return None + # Build the sampler. + + # Deprecated code + if self.args.use_legacy_prediction_loop: + if is_torch_xla_available(): + return SequentialDistributedSampler( + eval_dataset, num_replicas=xr.world_size(), rank=xr.global_ordinal() + ) + elif is_sagemaker_mp_enabled(): + return SequentialDistributedSampler( + eval_dataset, + num_replicas=smp.dp_size(), + rank=smp.dp_rank(), + batch_size=self.args.per_device_eval_batch_size, + ) + else: + return SequentialSampler(eval_dataset) + + if self.args.group_by_length: + if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset): + lengths = ( + eval_dataset[self.args.length_column_name] + if self.args.length_column_name in eval_dataset.column_names + else None + ) + else: + lengths = None + model_input_name = ( + self.processing_class.model_input_names[0] if self.processing_class is not None else None + ) + return LengthGroupedSampler( + self.args.eval_batch_size, + dataset=eval_dataset, + lengths=lengths, + model_input_name=model_input_name, + ) + + if self.args.world_size <= 1: + return SequentialSampler(eval_dataset) + else: + return None + + def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None) -> DataLoader: + """ + Returns the evaluation [`~torch.utils.data.DataLoader`]. + + Subclass and override this method if you want to inject some custom behavior. + + Args: + eval_dataset (`str` or `torch.utils.data.Dataset`, *optional*): + If a `str`, will use `self.eval_dataset[eval_dataset]` as the evaluation dataset. If a `Dataset`, will override `self.eval_dataset` and must implement `__len__`. If it is a [`~datasets.Dataset`], columns not accepted by the `model.forward()` method are automatically removed. + """ + if eval_dataset is None and self.eval_dataset is None: + raise ValueError("Trainer: evaluation requires an eval_dataset.") + + # If we have persistent workers, don't do a fork bomb especially as eval datasets + # don't change during training + dataloader_key = eval_dataset if isinstance(eval_dataset, str) else "eval" + if ( + hasattr(self, "_eval_dataloaders") + and dataloader_key in self._eval_dataloaders + and self.args.dataloader_persistent_workers + ): + return self._eval_dataloaders[dataloader_key] + + eval_dataset = ( + self.eval_dataset[eval_dataset] + if isinstance(eval_dataset, str) + else eval_dataset + if eval_dataset is not None + else self.eval_dataset + ) + + return self._get_dataloader( + dataset=eval_dataset, + description="Evaluation", + batch_size=self.args.eval_batch_size, + sampler_fn=self._get_eval_sampler, + dataloader_key=dataloader_key, + ) + + def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader: + """ + Returns the test [`~torch.utils.data.DataLoader`]. + + Subclass and override this method if you want to inject some custom behavior. + + Args: + test_dataset (`torch.utils.data.Dataset`, *optional*): + The test dataset to use. If it is a [`~datasets.Dataset`], columns not accepted by the + `model.forward()` method are automatically removed. It must implement `__len__`. + """ + return self._get_dataloader( + dataset=test_dataset, + description="test", + batch_size=self.args.eval_batch_size, + sampler_fn=self._get_eval_sampler, + ) + + def create_optimizer_and_scheduler(self, num_training_steps: int): + """ + Setup the optimizer and the learning rate scheduler. + + We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the + Trainer's init through `optimizers`, or subclass and override this method (or `create_optimizer` and/or + `create_scheduler`) in a subclass. + """ + self.create_optimizer() + if IS_SAGEMAKER_MP_POST_1_10 and smp.state.cfg.fp16: + # If smp >= 1.10 and fp16 is enabled, we unwrap the optimizer + optimizer = self.optimizer.optimizer + else: + optimizer = self.optimizer + self.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer) + + def get_decay_parameter_names(self, model) -> list[str]: + """ + Get all parameter names that weight decay will be applied to. + + This function filters out parameters in two ways: + 1. By layer type (instances of layers specified in ALL_LAYERNORM_LAYERS) + 2. By parameter name patterns (containing 'bias', or variation of 'norm') + """ + forbidden_name_patterns = [r"bias", r"layernorm", r"rmsnorm", r"(?:^|\.)norm(?:$|\.)", r"_norm(?:$|\.)"] + decay_parameters = get_parameter_names(model, [nn.LayerNorm], forbidden_name_patterns) + return decay_parameters + + def create_optimizer(self): + """ + Setup the optimizer. + + We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the + Trainer's init through `optimizers`, or subclass and override this method in a subclass. + """ + opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model + + if self.optimizer is None: + decay_parameters = self.get_decay_parameter_names(opt_model) + optimizer_grouped_parameters = [ + { + "params": [ + p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad) + ], + "weight_decay": self.args.weight_decay, + }, + { + "params": [ + p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad) + ], + "weight_decay": 0.0, + }, + ] + + if self.optimizer_cls_and_kwargs is not None: + optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs + else: + optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(self.args, opt_model) + + # Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs` + # e.g. for GaLore optimizer. + if "params" in optimizer_kwargs: + optimizer_grouped_parameters = optimizer_kwargs.pop("params") + + # Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs` + # e.g. for LOMO optimizer. + if "model" in optimizer_kwargs: + optimizer_grouped_parameters = optimizer_kwargs.pop("model") + + # For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict` + # to avoid arguments conflicts. + if "optimizer_dict" in optimizer_kwargs: + optimizer_grouped_parameters = optimizer_kwargs.pop("optimizer_dict") + + self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) + + if "bitsandbytes" in str(optimizer_cls) and optimizer_kwargs.get("optim_bits", None) == 8: + import bitsandbytes + + manager = bitsandbytes.optim.GlobalOptimManager.get_instance() + + skipped = 0 + for module in opt_model.modules(): + if isinstance(module, nn.Embedding): + skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values()) + logger.info(f"skipped {module}: {skipped / 2**20}M params") + manager.register_module_override(module, "weight", {"optim_bits": 32}) + logger.debug(f"bitsandbytes: will optimize {module} in fp32") + logger.info(f"skipped: {skipped / 2**20}M params") + + if is_sagemaker_mp_enabled(): + self.optimizer = smp.DistributedOptimizer(self.optimizer) + + return self.optimizer + + def get_num_trainable_parameters(self): + """ + Get the number of trainable parameters. + """ + return sum(p.numel() for p in self.model.parameters() if p.requires_grad) + + def get_learning_rates(self): + """ + Returns the learning rate of each parameter from self.optimizer. + """ + if self.optimizer is None: + raise ValueError("Trainer optimizer is None, please make sure you have setup the optimizer before.") + return [group["lr"] for group in self.optimizer.param_groups] + + def get_optimizer_group(self, param: Optional[Union[str, torch.nn.parameter.Parameter]] = None): + """ + Returns optimizer group for a parameter if given, else returns all optimizer groups for params. + + Args: + param (`str` or `torch.nn.parameter.Parameter`, *optional*): + The parameter for which optimizer group needs to be returned. + """ + if self.optimizer is None: + raise ValueError("Trainer optimizer is None, please make sure you have setup the optimizer before.") + if param is not None: + for group in self.optimizer.param_groups: + if param in group["params"]: + return group + return [group["params"] for group in self.optimizer.param_groups] + + @staticmethod + def get_optimizer_cls_and_kwargs( + args: TrainingArguments, model: Optional[PreTrainedModel] = None + ) -> tuple[Any, Any]: + """ + Returns the optimizer class and optimizer parameters based on the training arguments. + + Args: + args (`transformers.training_args.TrainingArguments`): + The training arguments for the training session. + + """ + + # parse args.optim_args + optim_args = {} + if args.optim_args: + for mapping in args.optim_args.replace(" ", "").split(","): + key, value = mapping.split("=") + optim_args[key] = value + + optimizer_kwargs = {"lr": args.learning_rate} + + adam_kwargs = { + "betas": (args.adam_beta1, args.adam_beta2), + "eps": args.adam_epsilon, + } + + def setup_low_rank_optimizer( + optimizer_name: str, + optimizer_mapping: dict[str, Any], + optim_kwargs: dict[str, Any], + is_layerwise_supported: bool = True, + ) -> tuple[Any, Any]: + """ + Helper function to set up low-rank optimizers like GaLore and Apollo. + + Args: + optimizer_name (str): Name of the optimizer. + optimizer_mapping (dict): Mapping of optimizer names to their classes. + optim_kwargs (dict): Keyword arguments for the optimizer. + is_layerwise_supported (bool): Whether layerwise optimization is supported. + + Returns: + tuple[Any, Any]: Optimizer class and updated optimizer kwargs. + """ + is_layerwise = optimizer_name.lower().endswith("layerwise") + if is_layerwise and args.parallel_mode == ParallelMode.DISTRIBUTED and is_layerwise_supported: + raise NotImplementedError(f"Layer-wise {optimizer_name} does not support DDP at this time") + + optimizer_cls = optimizer_mapping[optimizer_name] + + if args.optim_target_modules is None: + raise ValueError(f"You need to define `optim_target_modules` to use {optimizer_name} optimizers") + + if not isinstance(args.optim_target_modules, (list, str)): + raise TypeError( + f"`optim_target_modules` must be a list of strings, a regex string, or 'all-linear'. Got: {args.optim_target_modules}" + ) + + if model is None: + raise ValueError(f"You need to pass a model to initialize {optimizer_name} optimizer.") + + all_linear = ( + isinstance(args.optim_target_modules, str) + and args.optim_target_modules.replace("_", "-") == "all-linear" + ) + + target_params_names = [] + for module_name, module in model.named_modules(): + target_module_exists, is_regex = check_target_module_exists( + args.optim_target_modules, module_name, return_is_regex=True + ) + + if not isinstance(module, nn.Linear): + if target_module_exists and not is_regex: + logger.warning( + f"{module_name} matched but ignored. {optimizer_name} only supports linear layers." + ) + continue + + if not target_module_exists and not all_linear: + continue + + target_params_names.append(module_name + ".weight") + + if len(target_params_names) == 0: + raise ValueError(f"No target modules found for {optimizer_name} ({args.optim_target_modules}).") + + target_params = [p for n, p in model.named_parameters() if n in target_params_names] + non_target_params = [p for n, p in model.named_parameters() if n not in target_params_names] + optim_kwargs.update(optim_args) + + param_groups = [ + {"params": non_target_params}, + {"params": target_params, **optim_kwargs}, + ] + + if is_layerwise: + if args.gradient_accumulation_steps != 1: + raise ValueError(f"Layerwise {optimizer_name} does not support gradient accumulation!") + + optimizer_dict = {} + for param in non_target_params: + optimizer_dict[param] = optimizer_cls([{"params": [param]}], **optimizer_kwargs) + for param in target_params: + optimizer_dict[param] = optimizer_cls([{"params": [param], **optim_kwargs}], **optimizer_kwargs) + + def optimizer_hook(param): + if param.grad is not None: + optimizer_dict[param].step() + optimizer_dict[param].zero_grad() + + for param in model.parameters(): + if param.requires_grad: + param.register_post_accumulate_grad_hook(optimizer_hook) + + optimizer_cls = LayerWiseDummyOptimizer + optimizer_kwargs.update({"optimizer_dict": optimizer_dict}) + + optimizer_kwargs.update({"params": param_groups}) + return optimizer_cls, optimizer_kwargs + + if args.optim == OptimizerNames.ADAFACTOR: + optimizer_cls = Adafactor + optimizer_kwargs.update({"scale_parameter": False, "relative_step": False}) + elif args.optim in [OptimizerNames.ADAMW_TORCH, OptimizerNames.ADAMW_TORCH_FUSED]: + from torch.optim import AdamW + + optimizer_cls = AdamW + optimizer_kwargs.update(adam_kwargs) + if args.optim == OptimizerNames.ADAMW_TORCH_FUSED: + optimizer_kwargs.update({"fused": True}) + elif args.optim == OptimizerNames.ADAMW_TORCH_XLA: + try: + from torch_xla.amp.syncfree import AdamW + + optimizer_cls = AdamW + optimizer_kwargs.update(adam_kwargs) + except ImportError: + raise ValueError("Trainer failed to import syncfree AdamW from torch_xla.") + elif args.optim == OptimizerNames.ADAMW_TORCH_NPU_FUSED: + try: + from torch_npu.optim import NpuFusedAdamW + + optimizer_cls = NpuFusedAdamW + optimizer_kwargs.update(adam_kwargs) + except ImportError: + raise ValueError("Trainer failed to import FusedAdamW from torch_npu.") + elif args.optim == OptimizerNames.ADAMW_APEX_FUSED: + try: + from apex.optimizers import FusedAdam + + optimizer_cls = FusedAdam + optimizer_kwargs.update(adam_kwargs) + except ImportError: + raise ValueError("Trainer tried to instantiate apex FusedAdam but apex is not installed!") + elif args.optim in [ + OptimizerNames.ADAMW_BNB, + OptimizerNames.ADAMW_8BIT, + OptimizerNames.PAGED_ADAMW, + OptimizerNames.PAGED_ADAMW_8BIT, + OptimizerNames.ADEMAMIX, + OptimizerNames.ADEMAMIX_8BIT, + OptimizerNames.PAGED_ADEMAMIX, + OptimizerNames.PAGED_ADEMAMIX_8BIT, + OptimizerNames.LION, + OptimizerNames.LION_8BIT, + OptimizerNames.PAGED_LION, + OptimizerNames.PAGED_LION_8BIT, + OptimizerNames.RMSPROP_BNB, + OptimizerNames.RMSPROP_8BIT, + OptimizerNames.RMSPROP_32BIT, + ]: + try: + from bitsandbytes.optim import AdamW, Lion, RMSprop + + is_paged = False + optim_bits = 32 + optimizer_cls = None + additional_optim_kwargs = adam_kwargs + if "paged" in args.optim: + is_paged = True + if "8bit" in args.optim: + optim_bits = 8 + if "adam" in args.optim: + optimizer_cls = AdamW + elif "lion" in args.optim: + optimizer_cls = Lion + additional_optim_kwargs = {"betas": (args.adam_beta1, args.adam_beta2)} + elif "rmsprop" in args.optim: + optimizer_cls = RMSprop + # Above we pass all `adam_kwargs` to the optimizer, here + # we only pass `optim_args` which can be passed by the user. + additional_optim_kwargs = optim_args + elif "ademamix" in args.optim: + if is_bitsandbytes_available() and version.parse( + importlib.metadata.version("bitsandbytes") + ) < version.parse("0.44.0"): + raise ValueError( + "The AdEMAMix optimizer is not supported by your current version of `bitsandbytes`. " + "Please install `bitsandbytes` >= 0.44.0." + ) + + from bitsandbytes.optim import AdEMAMix + + optimizer_cls = AdEMAMix + additional_optim_kwargs = { + "betas": ( + float(optim_args.get("beta1", args.adam_beta1)), + float(optim_args.get("beta2", args.adam_beta2)), + float(optim_args.get("beta3", 0.9999)), + ), + "alpha": float(optim_args.get("alpha", 5.0)), + "eps": float(optim_args.get("eps", args.adam_epsilon)), + } + + if "t_alpha" in optim_args: + additional_optim_kwargs["t_alpha"] = int(optim_args["t_alpha"]) + + if "t_beta3" in optim_args: + additional_optim_kwargs["t_beta3"] = int(optim_args["t_beta3"]) + + bnb_kwargs = {"optim_bits": optim_bits} + if "rmsprop" not in args.optim: + bnb_kwargs["is_paged"] = is_paged + + optimizer_kwargs.update(additional_optim_kwargs) + optimizer_kwargs.update(bnb_kwargs) + except ImportError: + raise ValueError("Trainer tried to instantiate bnb optimizer but `bitsandbytes` is not installed!") + if is_bitsandbytes_available() and version.parse( + importlib.metadata.version("bitsandbytes") + ) < version.parse("0.41.1"): + logger.warning( + "You are using 8-bit optimizers with a version of `bitsandbytes` < 0.41.1. " + "It is recommended to update your version as a major bug has been fixed in 8-bit optimizers." + ) + elif args.optim == OptimizerNames.ADAMW_ANYPRECISION: + try: + from torchdistx.optimizers import AnyPrecisionAdamW + + optimizer_cls = AnyPrecisionAdamW + optimizer_kwargs.update(adam_kwargs) + + # TODO Change dtypes back to M=FP32, Var = BF16, Kahan = False once they can be cast together in torchdistx. + optimizer_kwargs.update( + { + "use_kahan_summation": strtobool(optim_args.get("use_kahan_summation", "False")), + "momentum_dtype": getattr(torch, optim_args.get("momentum_dtype", "float32")), + "variance_dtype": getattr(torch, optim_args.get("variance_dtype", "float32")), + "compensation_buffer_dtype": getattr( + torch, optim_args.get("compensation_buffer_dtype", "bfloat16") + ), + } + ) + except ImportError: + raise ValueError("Please install https://github.com/pytorch/torchdistx") + elif args.optim == OptimizerNames.SGD: + optimizer_cls = torch.optim.SGD + elif args.optim == OptimizerNames.ADAGRAD: + optimizer_cls = torch.optim.Adagrad + elif args.optim == OptimizerNames.RMSPROP: + optimizer_cls = torch.optim.RMSprop + elif args.optim in [ + OptimizerNames.GALORE_ADAMW, + OptimizerNames.GALORE_ADAMW_8BIT, + OptimizerNames.GALORE_ADAFACTOR, + OptimizerNames.GALORE_ADAMW_LAYERWISE, + OptimizerNames.GALORE_ADAMW_8BIT_LAYERWISE, + OptimizerNames.GALORE_ADAFACTOR_LAYERWISE, + ]: + if not is_galore_torch_available(): + raise ImportError( + "You need to install `galore_torch` in order to use GaLore optimizers" + " install it with `pip install git+https://github.com/jiaweizzhao/GaLore`" + ) + from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit + + optimizer_mapping = { + OptimizerNames.GALORE_ADAMW: GaLoreAdamW, + OptimizerNames.GALORE_ADAMW_8BIT: GaLoreAdamW8bit, + OptimizerNames.GALORE_ADAFACTOR: GaLoreAdafactor, + OptimizerNames.GALORE_ADAMW_LAYERWISE: GaLoreAdamW, + OptimizerNames.GALORE_ADAMW_8BIT_LAYERWISE: GaLoreAdamW8bit, + OptimizerNames.GALORE_ADAFACTOR_LAYERWISE: GaLoreAdafactor, + } + + galore_optim_kwargs = { + "rank": int(optim_args.pop("rank", 128)), + "update_proj_gap": int(optim_args.pop("update_proj_gap", 200)), + "scale": float(optim_args.pop("scale", 0.25)), + "proj_type": optim_args.pop("proj_type", "std"), + } + + optimizer_cls, optimizer_kwargs = setup_low_rank_optimizer( + args.optim, optimizer_mapping, galore_optim_kwargs + ) + if args.optim == OptimizerNames.GALORE_ADAFACTOR: + optimizer_kwargs.update({"scale_parameter": False, "relative_step": False}) + elif args.optim in [ + OptimizerNames.APOLLO_ADAMW, + OptimizerNames.APOLLO_ADAMW_LAYERWISE, + ]: + if not is_apollo_torch_available(): + raise ImportError( + "You need to install `apollo_torch` in order to use APOLLO optimizers" + " install it with `pip install git+https://github.com/zhuhanqing/APOLLO`" + ) + from apollo_torch import APOLLOAdamW + + optimizer_mapping = { + OptimizerNames.APOLLO_ADAMW: APOLLOAdamW, + OptimizerNames.APOLLO_ADAMW_LAYERWISE: APOLLOAdamW, + } + + apollo_optim_kwargs = { + "rank": int(optim_args.pop("rank", 128)), + "proj": optim_args.pop("proj", "random"), + "scale_type": optim_args.pop("scale_type", "channel"), + "update_proj_gap": int(optim_args.pop("update_proj_gap", 200)), + "scale": float(optim_args.pop("scale", 1.0)), + "proj_type": optim_args.pop("proj_type", "std"), + } + apollo_optim_kwargs.update(adam_kwargs) + + optimizer_cls, optimizer_kwargs = setup_low_rank_optimizer( + args.optim, optimizer_mapping, apollo_optim_kwargs + ) + elif args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]: + if not is_lomo_available(): + raise ImportError( + "You need to install `lomo_optim` in order to use LOMO optimizers" + " install it with `pip install lomo-optim`" + ) + if not is_accelerate_available("0.30.0"): + raise ImportError("You need to have `accelerate>=0.30.0` to be able to use LOMO optimizers") + + if model is None: + raise ValueError("You need to pass a `model` in order to correctly initialize a LOMO optimizer.") + + from lomo_optim import AdaLomo, Lomo + + if "ada" in args.optim: + optimizer_cls = AdaLomo + else: + optimizer_cls = Lomo + + optimizer_kwargs.update({"model": model}) + elif args.optim == OptimizerNames.GROKADAMW: + if not is_grokadamw_available(): + raise ValueError("Please install grokadamw with `pip install grokadamw`") + + from grokadamw import GrokAdamW + + optimizer_cls = GrokAdamW + optimizer_kwargs.update( + { + "alpha_init": float(optim_args.get("alpha_init", 0.98)), + "lamb": float(optim_args.get("lamb", 2.0)), + "gamma": float(optim_args.get("gamma", 0.1)), + "grokking_signal_decay_rate": float(optim_args.get("grokking_signal_decay_rate", 0.1)), + "gradient_clipping": float(optim_args.get("gradient_clipping", 1.0)), + } + ) + elif args.optim in [ + OptimizerNames.ADAMW_TORCH_4BIT, + OptimizerNames.ADAMW_TORCH_8BIT, + ]: + if not is_torchao_available() or version.parse(importlib.metadata.version("torchao")) < version.parse( + "0.4.0" + ): + raise ImportError( + "You need to have `torchao>=0.4.0` in order to use torch 4-bit optimizers." + "Install it with `pip install torchao` or follow the instructions here: https://github.com/pytorch/ao" + ) + if version.parse(importlib.metadata.version("torch")) <= version.parse("2.4"): + raise ImportError( + "You need to have `torch>2.4` in order to use torch 4-bit optimizers. " + "Install it with `pip install --upgrade torch` it is available on pipy. Otherwise, you need to install torch nightly." + ) + if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.11.0"): + # https://github.com/pytorch/ao/pull/2159 + from torchao.optim import AdamW4bit, AdamW8bit + else: + from torchao.prototype.low_bit_optim import AdamW4bit, AdamW8bit + if args.optim == OptimizerNames.ADAMW_TORCH_4BIT: + optimizer_cls = AdamW4bit + elif args.optim == OptimizerNames.ADAMW_TORCH_8BIT: + optimizer_cls = AdamW8bit + else: + raise ValueError("Invalid optimizer") + optimizer_kwargs.update(adam_kwargs) + elif args.optim in [ + OptimizerNames.SCHEDULE_FREE_RADAM, + OptimizerNames.SCHEDULE_FREE_ADAMW, + OptimizerNames.SCHEDULE_FREE_SGD, + ]: + if not is_schedulefree_available(): + raise ImportError( + "You need to install `schedulefree` in order to use schedulefree optimizers. " + "Install it with `pip install schedulefree.`" + ) + if not is_accelerate_available("0.30.0"): + raise ImportError("You need to have `accelerate>=0.30.0` to be able to use schedulefree optimizers") + from schedulefree import AdamWScheduleFree, SGDScheduleFree + + additional_optim_kwargs = {} + require_warmup = True + + if args.optim == OptimizerNames.SCHEDULE_FREE_RADAM: + if not is_schedulefree_available("1.4.0"): + raise ImportError( + "You need to install `schedulefree>=1.4.0` in order to use RAdamScheduleFree optimizer. " + "Install it with `pip install schedulefree.`" + ) + from schedulefree import RAdamScheduleFree + + optimizer_cls = RAdamScheduleFree + additional_optim_kwargs = adam_kwargs + require_warmup = False + elif args.optim == OptimizerNames.SCHEDULE_FREE_ADAMW: + optimizer_cls = AdamWScheduleFree + additional_optim_kwargs = adam_kwargs + elif args.optim == OptimizerNames.SCHEDULE_FREE_SGD: + optimizer_cls = SGDScheduleFree + else: + raise ValueError("Invalid schedulefree optimizer") + + additional_optim_kwargs["weight_decay"] = args.weight_decay + if require_warmup: + additional_optim_kwargs["warmup_steps"] = args.warmup_steps + additional_optim_kwargs.update( + { + "weight_lr_power": float(optim_args.get("weight_lr_power", 2.0)), + "r": float(optim_args.get("r", 0.0)), + } + ) + optimizer_kwargs.update(additional_optim_kwargs) + elif args.optim == OptimizerNames.STABLE_ADAMW: + if not is_torch_optimi_available(): + raise ImportError( + "You need to install `torch-optimi` in order to use stable_adamw optimizers. " + "Install it with `pip install torch-optimi`." + ) + from optimi import StableAdamW + + max_lr = optim_args.pop("max_lr", None) + if max_lr is not None: + max_lr = float(max_lr) + + kahan_sum = optim_args.pop("kahan_sum", None) + if kahan_sum is not None: + kahan_sum = bool(kahan_sum) + + adam_kwargs["weight_decay"] = args.weight_decay + stable_adamw_kwargs = { + "decouple_lr": bool(optim_args.pop("decouple_lr", False)), + "max_lr": max_lr, + "kahan_sum": kahan_sum, + } + + optimizer_cls = StableAdamW + optimizer_kwargs.update(adam_kwargs) + optimizer_kwargs.update(stable_adamw_kwargs) + else: + raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}") + return optimizer_cls, optimizer_kwargs + + def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None): + """ + Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or + passed as an argument. + + Args: + num_training_steps (int): The number of training steps to do. + """ + if self.lr_scheduler is None: + self.lr_scheduler = get_scheduler( + self.args.lr_scheduler_type, + optimizer=self.optimizer if optimizer is None else optimizer, + num_warmup_steps=self.args.get_warmup_steps(num_training_steps), + num_training_steps=num_training_steps, + scheduler_specific_kwargs=self.args.lr_scheduler_kwargs, + ) + self._created_lr_scheduler = True + return self.lr_scheduler + + def num_examples(self, dataloader: DataLoader) -> int: + """ + Helper to get number of samples in a [`~torch.utils.data.DataLoader`] by accessing its dataset. When + dataloader.dataset does not exist or has no length, estimates as best it can + """ + try: + dataset = dataloader.dataset + # Special case for IterableDatasetShard, we need to dig deeper + if isinstance(dataset, IterableDatasetShard): + return len(dataloader.dataset.dataset) + return len(dataloader.dataset) + except (NameError, AttributeError, TypeError): # no dataset or length, estimate by length of dataloader + return len(dataloader) * self.args.per_device_train_batch_size + + @staticmethod + def num_tokens(train_dl: DataLoader, max_steps: Optional[int] = None) -> int: + """ + Helper to get number of tokens in a [`~torch.utils.data.DataLoader`] by enumerating dataloader. + """ + train_tokens = 0 + try: + for batch in train_dl: + tokens = batch["input_ids"].numel() + if max_steps is not None: + return tokens * max_steps + train_tokens += tokens + except KeyError: + logger.warning("Cannot get num_tokens from dataloader") + return train_tokens + + def _hp_search_setup(self, trial: Union["optuna.Trial", dict[str, Any]]): + """HP search setup code""" + self._trial = trial + + if self.hp_search_backend is None or trial is None: + return + if self.hp_search_backend == HPSearchBackend.OPTUNA: + params = self.hp_space(trial) + elif self.hp_search_backend == HPSearchBackend.RAY: + params = trial + params.pop("wandb", None) + elif self.hp_search_backend == HPSearchBackend.SIGOPT: + params = {k: int(v) if isinstance(v, str) else v for k, v in trial.assignments.items()} + elif self.hp_search_backend == HPSearchBackend.WANDB: + params = trial + + for key, value in params.items(): + if not hasattr(self.args, key): + logger.warning( + f"Trying to set {key} in the hyperparameter search but there is no corresponding field in" + " `TrainingArguments`." + ) + continue + old_attr = getattr(self.args, key, None) + # Casting value to the proper type + if old_attr is not None: + value = type(old_attr)(value) + + setattr(self.args, key, value) + if self.hp_search_backend == HPSearchBackend.OPTUNA: + logger.info(f"Trial: {trial.params}") + if self.hp_search_backend == HPSearchBackend.SIGOPT: + logger.info(f"SigOpt Assignments: {trial.assignments}") + if self.hp_search_backend == HPSearchBackend.WANDB: + logger.info(f"W&B Sweep parameters: {trial}") + if self.is_deepspeed_enabled: + if self.args.deepspeed is None: + raise ValueError("For sweeps with deepspeed, `args.deepspeed` must be set") + + self.accelerator.free_memory() + + # Rebuild the deepspeed config to reflect the updated training parameters + from accelerate.utils import DeepSpeedPlugin + + from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig + + self.args.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.args.deepspeed) + self.args.hf_deepspeed_config.trainer_config_process(self.args) + self.args.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.args.hf_deepspeed_config) + + # From 1.0 on, we need to fully wipe the DS plugin when doing sweeps. + # Simply calling `_reset_state` is enough and doesn't need a version pin. + AcceleratorState()._reset_state() + + self.create_accelerator_and_postprocess() + + def _report_to_hp_search(self, trial: Union["optuna.Trial", dict[str, Any]], step: int, metrics: dict[str, float]): + if self.hp_search_backend is None or trial is None: + return + metrics = metrics.copy() + self.objective = self.compute_objective(metrics) + if self.hp_search_backend == HPSearchBackend.OPTUNA: + import optuna + + if hasattr(trial, "study") and not trial.study._is_multi_objective(): + trial.report(self.objective, step) + if trial.should_prune(): + self.callback_handler.on_train_end(self.args, self.state, self.control) + raise optuna.TrialPruned() + elif self.hp_search_backend == HPSearchBackend.RAY: + import ray.train + + with tempfile.TemporaryDirectory() as temp_checkpoint_dir: + checkpoint = None + if self.control.should_save: + self._tune_save_checkpoint(checkpoint_dir=temp_checkpoint_dir) + checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir) + metrics["objective"] = self.objective + ray.train.report(metrics, checkpoint=checkpoint) + + def _tune_save_checkpoint(self, checkpoint_dir: str): + output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}") + self.save_model(output_dir, _internal_call=True) + if self.args.should_save: + # Update the `TrainerControl` state to where we are currently + self.state.stateful_callbacks["TrainerControl"] = self.control.state() + self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME)) + torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) + torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) + + def call_model_init(self, trial=None): + model_init_argcount = number_of_arguments(self.model_init) + if model_init_argcount == 0: + model = self.model_init() + elif model_init_argcount == 1: + model = self.model_init(trial) + else: + raise RuntimeError("model_init should have 0 or 1 argument.") + + if model is None: + raise RuntimeError("model_init should not return None.") + + return model + + def torch_jit_model_eval(self, model, dataloader, training=False): + if not training: + if dataloader is None: + logger.warning("failed to use PyTorch jit mode due to current dataloader is none.") + return model + example_batch = next(iter(dataloader)) + example_batch = self._prepare_inputs(example_batch) + try: + jit_model = copy.copy(model) + jit_model.eval() + original_forward = jit_model.__dict__.pop("_original_forward", None) + # remove mixed precision hooks from the model + if original_forward: + jit_model.forward = original_forward + autocast_handler = AutocastKwargs(cache_enabled=False) + with self.accelerator.autocast(autocast_handler=autocast_handler), torch.no_grad(): + if isinstance(example_batch, dict): + jit_model = torch.jit.trace(jit_model, example_kwarg_inputs=example_batch, strict=False) + else: + jit_model = torch.jit.trace( + jit_model, + example_kwarg_inputs={key: example_batch[key] for key in example_batch}, + strict=False, + ) + jit_model = torch.jit.freeze(jit_model) + with torch.no_grad(): + jit_model(**example_batch) + jit_model(**example_batch) + model = jit_model + self.use_cpu_amp = False + except (RuntimeError, TypeError, ValueError, NameError, IndexError) as e: + logger.warning(f"failed to use PyTorch jit mode due to: {e}.") + + return model + + def compare_trainer_and_checkpoint_args(self, training_args, trainer_state): + attributes_map = { + "logging_steps": "logging_steps", + "eval_steps": "eval_steps", + "save_steps": "save_steps", + } + + has_warning = False + warning_str = "Warning: The following arguments do not match the ones in the `trainer_state.json` within the checkpoint directory: " + for arg_attr, state_attr in attributes_map.items(): + arg_value = getattr(training_args, arg_attr, None) + state_value = getattr(trainer_state, state_attr, None) + + if arg_value is not None and state_value is not None and arg_value != state_value: + warning_str += f"\n\t{arg_attr}: {arg_value} (from args) != {state_value} (from trainer_state.json)" + has_warning = True + + # train bs is special as we need to account for multi-GPU + train_bs_args = training_args.per_device_train_batch_size + train_bs_state = trainer_state.train_batch_size // max(1, training_args.n_gpu) + + if train_bs_args != train_bs_state: + warning_str += f"\n\tper_device_train_batch_size: {train_bs_args} (from args) != {train_bs_state} (from trainer_state.json)" + has_warning = True + + if has_warning: + logger.warning_once(warning_str) + + def _wrap_model(self, model, training=True, dataloader=None): + if is_sagemaker_mp_enabled(): + # Wrapping the base model twice in a DistributedModel will raise an error. + if isinstance(self.model_wrapped, smp.model.DistributedModel): + return self.model_wrapped + return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps) + + # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again + if self.accelerator.unwrap_model(model, keep_torch_compile=False) is not model: + return model + + # Mixed precision training with apex + if self.use_apex and training: + from apex import amp + + model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level) + + # Multi-gpu training (should be after apex fp16 initialization) / 8bit models does not support DDP + if self.args.n_gpu > 1 and not getattr(model, "is_loaded_in_8bit", False): + model = nn.DataParallel(model) + + if self.args.jit_mode_eval: + start_time = time.time() + model = self.torch_jit_model_eval(model, dataloader, training) + self.jit_compilation_time = round(time.time() - start_time, 4) + + # Note: in torch.distributed mode, there's no point in wrapping the model + # inside a DistributedDataParallel as we'll be under `no_grad` anyways. + if not training: + return model + + # Distributed training (should be after apex fp16 initialization) + # Distributed training using PyTorch FSDP + if self.is_fsdp_xla_enabled: + try: + from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP + from torch_xla.distributed.fsdp import checkpoint_module + from torch_xla.distributed.fsdp.wrap import ( + size_based_auto_wrap_policy, + transformer_auto_wrap_policy, + ) + + if self.is_fsdp_xla_v2_enabled: + from torch_xla.experimental.spmd_fully_sharded_data_parallel import ( + SpmdFullyShardedDataParallel as FSDPv2, + ) + except ImportError: + raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.") + auto_wrap_policy = None + auto_wrapper_callable = None + default_transformer_cls_names_to_wrap = getattr(model, "_no_split_modules", None) + fsdp_transformer_layer_cls_to_wrap = self.args.fsdp_config.get( + "transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap + ) + + if self.args.fsdp_config["min_num_params"] > 0: + auto_wrap_policy = functools.partial( + size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["min_num_params"] + ) + elif fsdp_transformer_layer_cls_to_wrap is not None: + transformer_cls_to_wrap = set() + for layer_class in fsdp_transformer_layer_cls_to_wrap: + transformer_cls = get_module_class_from_name(model, layer_class) + if transformer_cls is None: + raise Exception("Could not find the transformer layer class to wrap in the model.") + else: + transformer_cls_to_wrap.add(transformer_cls) + + auto_wrap_policy = functools.partial( + transformer_auto_wrap_policy, + # Transformer layer class to wrap + transformer_layer_cls=transformer_cls_to_wrap, + ) + fsdp_kwargs = self.args.xla_fsdp_config + if self.args.fsdp_config["xla_fsdp_grad_ckpt"]: + if model.config.use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + model.config.use_cache = False + + # Apply gradient checkpointing to auto-wrapped sub-modules if specified + def auto_wrapper_callable(m, *args, **kwargs): + target_cls = FSDP if not self.is_fsdp_xla_v2_enabled else FSDPv2 + return target_cls(checkpoint_module(m), *args, **kwargs) + + # Wrap the base model with an outer FSDP wrapper + if self.is_fsdp_xla_v2_enabled: + + def shard_output(output, mesh): + from .modeling_outputs import CausalLMOutputWithPast + + real_output = None + if isinstance(output, torch.Tensor): + real_output = output + elif isinstance(output, tuple): + real_output = output[0] + elif isinstance(output, CausalLMOutputWithPast): + real_output = output.logits + + if real_output is None: + raise ValueError("Something went wrong, the output of the model shouldn't be `None`") + xs.mark_sharding(real_output, mesh, ("fsdp", None, None)) + + self.model = model = FSDPv2( + model, + shard_output=shard_output, + auto_wrap_policy=auto_wrap_policy, + auto_wrapper_callable=auto_wrapper_callable, + ) + else: + self.model = model = FSDP( + model, + auto_wrap_policy=auto_wrap_policy, + auto_wrapper_callable=auto_wrapper_callable, + **fsdp_kwargs, + ) + + # Patch `xm.optimizer_step` should not reduce gradients in this case, + # as FSDP does not need gradient reduction over sharded parameters. + def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}): + loss = optimizer.step(**optimizer_args) + if barrier: + xm.mark_step() + return loss + + xm.optimizer_step = patched_optimizer_step + elif is_sagemaker_dp_enabled(): + model = nn.parallel.DistributedDataParallel( + model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))] + ) + elif self.args.parallel_mode == ParallelMode.DISTRIBUTED: + if is_torch_neuroncore_available(): + return model + kwargs = {} + if self.args.ddp_find_unused_parameters is not None: + kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters + elif isinstance(model, PreTrainedModel): + # find_unused_parameters breaks checkpointing as per + # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021 + kwargs["find_unused_parameters"] = not model.is_gradient_checkpointing + else: + kwargs["find_unused_parameters"] = True + + if self.args.ddp_bucket_cap_mb is not None: + kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb + + if self.args.ddp_broadcast_buffers is not None: + kwargs["broadcast_buffers"] = self.args.ddp_broadcast_buffers + + self.accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs) + + return model + + def train( + self, + resume_from_checkpoint: Optional[Union[str, bool]] = None, + trial: Union["optuna.Trial", dict[str, Any], None] = None, + ignore_keys_for_eval: Optional[list[str]] = None, + **kwargs: Any, + ): + """ + Main training entry point. + + Args: + resume_from_checkpoint (`str` or `bool`, *optional*): + If a `str`, local path to a saved checkpoint as saved by a previous instance of [`Trainer`]. If a + `bool` and equals `True`, load the last checkpoint in *args.output_dir* as saved by a previous instance + of [`Trainer`]. If present, training will resume from the model/optimizer/scheduler states loaded here. + trial (`optuna.Trial` or `dict[str, Any]`, *optional*): + The trial run or the hyperparameter dictionary for hyperparameter search. + ignore_keys_for_eval (`list[str]`, *optional*) + A list of keys in the output of your model (if it is a dictionary) that should be ignored when + gathering predictions for evaluation during the training. + kwargs (`dict[str, Any]`, *optional*): + Additional keyword arguments used to hide deprecated arguments + """ + if resume_from_checkpoint is False: + resume_from_checkpoint = None + + # memory metrics - must set up as early as possible + self._memory_tracker.start() + + args = self.args + + self.is_in_train = True + + # If the model uses a tokenizer, it may have a new tokens for fine-tuning purposes. + if isinstance(self.processing_class, (PreTrainedTokenizerBase, ProcessorMixin)) and hasattr( + self.model, "config" + ): + self._align_special_tokens() + + # Attach NEFTune hooks if necessary + if self.neftune_noise_alpha is not None: + self.model = self._activate_neftune(self.model) + + # do_train is not a reliable argument, as it might not be set and .train() still called, so + # the following is a workaround: + if ( + (args.fp16_full_eval or args.bf16_full_eval) + and not args.do_train + and not self.is_model_parallel + and self.model_init is None + ): + self._move_model_to_device(self.model, args.device) + + if "model_path" in kwargs: + resume_from_checkpoint = kwargs.pop("model_path") + warnings.warn( + "`model_path` is deprecated and will be removed in a future version. Use `resume_from_checkpoint` " + "instead.", + FutureWarning, + ) + if len(kwargs) > 0: + raise TypeError(f"train() got unexpected keyword arguments: {', '.join(list(kwargs.keys()))}.") + # This might change the seed so needs to run first. + self._hp_search_setup(trial) + self._train_batch_size = self.args.train_batch_size + + # Model re-init + model_reloaded = False + if self.model_init is not None: + # Seed must be set before instantiating the model when using model_init. + enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed) + self.model = self.call_model_init(trial) + model_reloaded = True + # Reinitializes optimizer and scheduler + self.optimizer, self.lr_scheduler = None, None + + # Load potential model checkpoint + if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint: + resume_from_checkpoint = get_last_checkpoint(args.output_dir) + if resume_from_checkpoint is None: + raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})") + + if resume_from_checkpoint is not None: + if not is_sagemaker_mp_enabled() and not self.is_deepspeed_enabled and not self.is_fsdp_enabled: + self._load_from_checkpoint(resume_from_checkpoint) + # In case of repeating the find_executable_batch_size, set `self._train_batch_size` properly + state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)) + if state.train_batch_size is not None: + self._train_batch_size = state.train_batch_size + + # If model was re-initialized, put it on the right device and update self.model_wrapped + if model_reloaded: + if self.place_model_on_device: + self._move_model_to_device(self.model, args.device) + self.model_wrapped = self.model + + inner_training_loop = find_executable_batch_size( + self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size + ) + if args.push_to_hub: + try: + # Disable progress bars when uploading models during checkpoints to avoid polluting stdout + hf_hub_utils.disable_progress_bars() + return inner_training_loop( + args=args, + resume_from_checkpoint=resume_from_checkpoint, + trial=trial, + ignore_keys_for_eval=ignore_keys_for_eval, + ) + finally: + hf_hub_utils.enable_progress_bars() + else: + return inner_training_loop( + args=args, + resume_from_checkpoint=resume_from_checkpoint, + trial=trial, + ignore_keys_for_eval=ignore_keys_for_eval, + ) + + def get_tp_size(self) -> int: + """Get the tensor parallel size from either the model or DeepSpeed config.""" + + # 1. Check model.tp_size first + if (model_tp := getattr(self.model, "_tp_size", None)) is not None: + return model_tp + + # 2. Fall back to DeepSpeed config if enabled + if self.is_deepspeed_enabled and (deepspeed_config := getattr(self.args, "hf_deepspeed_config", None)): + return deepspeed_config.config.get("tensor_parallel", {}).get("autotp_size", 1) + + # 3. Default fallback + return 1 + + def get_total_train_batch_size(self, args) -> int: + """Calculates total batch size (micro_batch * grad_accum * dp_world_size). + + Note: Only considers DP and TP (dp_world_size = world_size // tp_size).""" + dp_world_size = args.world_size // self.get_tp_size() + return self._train_batch_size * args.gradient_accumulation_steps * dp_world_size + + def _inner_training_loop( + self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None + ): + self.accelerator.free_memory() + self._train_batch_size = batch_size + if self.args.auto_find_batch_size: + if self.state.train_batch_size != self._train_batch_size: + from accelerate.utils import release_memory + + (self.model_wrapped,) = release_memory(self.model_wrapped) + self.model_wrapped = self.model + + # Check for DeepSpeed *after* the initial pass and modify the config + if self.is_deepspeed_enabled: + # Temporarily unset `self.args.train_batch_size` + original_bs = self.args.per_device_train_batch_size + self.args.per_device_train_batch_size = self._train_batch_size // max(1, self.args.n_gpu) + self.propagate_args_to_deepspeed(True) + self.args.per_device_train_batch_size = original_bs + self.state.train_batch_size = self._train_batch_size + logger.debug(f"Currently training with a batch size of: {self._train_batch_size}") + # Data loader and number of training steps + train_dataloader = self.get_train_dataloader() + if self.is_fsdp_xla_v2_enabled: + train_dataloader = tpu_spmd_dataloader(train_dataloader) + + # Setting up training control variables: + # number of training epochs: num_train_epochs + # number of training steps per epoch: num_update_steps_per_epoch + # total number of training steps to execute: max_steps + total_train_batch_size = self.get_total_train_batch_size(args) + + ( + num_train_epochs, + num_update_steps_per_epoch, + num_examples, + num_train_samples, + epoch_based, + len_dataloader, + max_steps, + ) = self.set_initial_training_values(args, train_dataloader, total_train_batch_size) + + num_train_tokens = None + if self.args.include_tokens_per_second: + num_train_tokens = self.num_tokens(train_dataloader, None if epoch_based else max_steps) + # If going by epochs, multiply tokens linearly + if len_dataloader is not None and epoch_based: + num_train_tokens *= args.num_train_epochs + # Otherwise since its steps, we just multiply by grad accum + else: + num_train_tokens *= args.gradient_accumulation_steps + + if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug: + if self.args.n_gpu > 1: + # nn.DataParallel(model) replicates the model, creating new variables and module + # references registered here no longer work on other gpus, breaking the module + raise ValueError( + "Currently --debug underflow_overflow is not supported under DP. Please use DDP" + " (torchrun or torch.distributed.launch (deprecated))." + ) + else: + DebugUnderflowOverflow(self.model) + + delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled + + # Can't delay optimizer creation when using FSDP2: https://github.com/huggingface/accelerate/blob/3f636d626063ffcf9a337c7d3624d61b7d187d59/src/accelerate/accelerator.py#L1404 + is_fsdp2 = self.is_fsdp_enabled and (getattr(self.accelerator.state.fsdp_plugin, "fsdp_version", 1) == 2) + if is_fsdp2: + delay_optimizer_creation = False + + # We need to reset the scheduler, as its parameters may be different on subsequent calls + if self._created_lr_scheduler: + self.lr_scheduler = None + self._created_lr_scheduler = False + + if self.is_deepspeed_enabled: + self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps) + + if not delay_optimizer_creation: + self.create_optimizer_and_scheduler(num_training_steps=max_steps) + + self.state = TrainerState( + stateful_callbacks=[ + cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState) + ] + ) + self.state.is_hyper_param_search = trial is not None + self.state.train_batch_size = self._train_batch_size + + # Compute absolute values for logging, eval, and save if given as ratio + self.state.compute_steps(args, max_steps) + + # Activate gradient checkpointing if needed + if args.gradient_checkpointing: + self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs) + + model = self._wrap_model(self.model_wrapped) + + # as the model is wrapped, don't use `accelerator.prepare` + # this is for unhandled cases such as + # FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX + use_accelerator_prepare = model is self.model + + if use_accelerator_prepare and self.is_fsdp_enabled: + # In case of auto_find_batch_size=True + # Remove FSDP wrapping from sub-models. + self.model = unwrap_model(self.model, recursive=True) + + if delay_optimizer_creation: + if use_accelerator_prepare: + # configure fsdp plugin for qlora if any + self._fsdp_qlora_plugin_updates() + if self.accelerator.mixed_precision != "fp8": + self.model = self.accelerator.prepare(self.model) + self.create_optimizer_and_scheduler(num_training_steps=max_steps) + + # prepare using `accelerator` prepare + if use_accelerator_prepare: + self.model.train() + if hasattr(self.lr_scheduler, "step"): + if self.use_apex: + model = self.accelerator.prepare(self.model) + else: + # We should avoid accelerate preparing the model in TP case since we dont need it as it is handled by transformers from_pretrained and also it goes into DDP based preparation. + if self.is_tp_enabled: + self.optimizer = self.accelerator.prepare(self.optimizer) + else: + model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer) + else: + # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config. + model, self.optimizer, self.lr_scheduler = self.accelerator.prepare( + self.model, self.optimizer, self.lr_scheduler + ) + else: + self.optimizer = self.accelerator.prepare(self.optimizer) + + if self.is_fsdp_enabled: + self.model = self.model_wrapped = model + + # for the rest of this function `model` is the outside model, whether it was wrapped or not + if model is not self.model: + self.model_wrapped = model + + # backward compatibility + if self.is_deepspeed_enabled: + self.deepspeed = self.model_wrapped + + # ckpt loading + if resume_from_checkpoint is not None: + if self.is_deepspeed_enabled: + deepspeed_load_checkpoint( + self.model_wrapped, resume_from_checkpoint, load_module_strict=not _is_peft_model(self.model) + ) + elif is_sagemaker_mp_enabled() or self.is_fsdp_enabled: + self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped) + + # Check if saved optimizer or scheduler states exist + self._load_optimizer_and_scheduler(resume_from_checkpoint) + self._load_scaler(resume_from_checkpoint) + + # important: at this point: + # self.model is the Transformers Model + # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), + # FSDP(Transformers Model), Dynamo Optimized Module(Transformers Model) etc. + + # Train! + logger.info("***** Running training *****") + logger.info(f" Num examples = {num_examples:,}") + logger.info(f" Num Epochs = {num_train_epochs:,}") + logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}") + if self.args.per_device_train_batch_size != self._train_batch_size: + logger.info(f" Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {max_steps:,}") + logger.info(f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}") + + self.state.epoch = 0 + start_time = time.time() + epochs_trained = 0 + steps_trained_in_current_epoch = 0 + + # Check if continuing training from a checkpoint + if resume_from_checkpoint is not None and os.path.isfile( + os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME) + ): + self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)) + self.compare_trainer_and_checkpoint_args(self.args, self.state) + self._load_callback_state() + epochs_trained = int(self.state.global_step // num_update_steps_per_epoch) + if not args.ignore_data_skip: + steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch) + steps_trained_in_current_epoch *= args.gradient_accumulation_steps + else: + steps_trained_in_current_epoch = 0 + + logger.info(" Continuing training from checkpoint, will skip to saved global_step") + logger.info(f" Continuing training from epoch {epochs_trained}") + logger.info(f" Continuing training from global step {self.state.global_step}") + if not args.ignore_data_skip: + logger.info( + f" Will skip the first {epochs_trained} epochs then the first" + f" {steps_trained_in_current_epoch} batches in the first epoch." + ) + + # Update the references + for attr in ("model", "optimizer", "lr_scheduler"): + setattr(self.callback_handler, attr, getattr(self, attr)) + self.callback_handler.train_dataloader = train_dataloader + + self.state.init_training_references(self, max_steps, num_train_epochs, trial) + + # tr_loss is a tensor to avoid synchronization of TPUs through .item() + tr_loss = torch.tensor(0.0, device=args.device) + # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses + self._total_loss_scalar = 0.0 + self._globalstep_last_logged = self.state.global_step + model.zero_grad() + grad_norm: Optional[float] = None + learning_rate = None + self.control = self.callback_handler.on_train_begin(args, self.state, self.control) + + if args.eval_on_start: + self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True) + + for epoch in range(epochs_trained, num_train_epochs): + epoch_dataloader = train_dataloader + if hasattr(epoch_dataloader, "set_epoch"): + epoch_dataloader.set_epoch(epoch) + + # Reset the past mems state at the beginning of each epoch if necessary. + if args.past_index >= 0: + self._past = None + + steps_in_epoch = ( + len(epoch_dataloader) + if len_dataloader is not None + else args.max_steps * args.gradient_accumulation_steps + ) + self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control) + + step = -1 + rng_to_sync = False + + # Handle resumption from checkpoint + if epoch == epochs_trained and resume_from_checkpoint is not None: + if steps_trained_in_current_epoch > 0 and not args.ignore_data_skip: + epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch) + step = steps_trained_in_current_epoch - 1 + rng_to_sync = True + elif steps_trained_in_current_epoch == 0: + self._load_rng_state(resume_from_checkpoint) + + epoch_iterator = iter(epoch_dataloader) + # We chunkify the epoch iterator into gradient accumulation steps `n` batches + remainder = steps_in_epoch % args.gradient_accumulation_steps + if remainder == 0: + remainder = args.gradient_accumulation_steps + update_step = -1 + total_updates = steps_in_epoch // args.gradient_accumulation_steps + int( + remainder < args.gradient_accumulation_steps + ) + for _ in range(total_updates): + update_step += 1 + num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder + batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches, args.device) + # Store the number of batches for current gradient accumulation + # This is used to correctly scale the loss when the last accumulation step has fewer batches + self.current_gradient_accumulation_steps = len(batch_samples) + for i, inputs in enumerate(batch_samples): + step += 1 + do_sync_step = (step + 1) % args.gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch + # Since we perform prefetching, we need to manually set sync_gradients + self.accelerator.gradient_state._set_sync_gradients(do_sync_step) + + if self.args.include_num_input_tokens_seen not in ["no", False]: + main_input_name = getattr(self.model, "main_input_name", "input_ids") + if main_input_name not in inputs: + logger.warning( + "Tried to track the number of tokens seen, however the current model is " + "not configured properly to know what item is the input. To fix this, add " + "a `main_input_name` attribute to the model class you are using." + ) + else: + if self.args.include_num_input_tokens_seen == "non_padding": + if "attention_mask" in inputs: + input_tokens = inputs["attention_mask"].sum() + elif ( + self.processing_class is not None + and hasattr(self.processing_class, "pad_token_id") + and self.processing_class.pad_token_id is not None + ): + input_tokens = ( + inputs[main_input_name] != self.processing_class.pad_token_id + ).sum() + else: + logger.warning( + "Could not determine method to count non-padding tokens, falling back to counting all tokens." + ) + input_tokens = inputs[main_input_name].numel() + else: + input_tokens = inputs[main_input_name].numel() + + input_tokens = torch.tensor(input_tokens, device=self.args.device, dtype=torch.int64) + self.state.num_input_tokens_seen += self.accelerator.gather(input_tokens).sum().item() + + if rng_to_sync: + self._load_rng_state(resume_from_checkpoint) + rng_to_sync = False + + if step % args.gradient_accumulation_steps == 0: + self.control = self.callback_handler.on_step_begin(args, self.state, self.control) + + # We explicitly want to avoid relying on `accelerator.accumulate` for generation training + context = ( + functools.partial(self.accelerator.no_sync, model=model) + if i != len(batch_samples) - 1 + and self.accelerator.distributed_type != DistributedType.DEEPSPEED + else contextlib.nullcontext + ) + with context(): + tr_loss_step = self.training_step(model, inputs, num_items_in_batch) + + if ( + args.logging_nan_inf_filter + and not is_torch_xla_available() + and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)) + ): + # if loss is nan or inf simply add the average of previous logged losses + tr_loss = tr_loss + tr_loss / (1 + self.state.global_step - self._globalstep_last_logged) + else: + if tr_loss.device != tr_loss_step.device: + raise ValueError( + f"Calculated loss must be on the original device: {tr_loss.device} but device in use is {tr_loss_step.device}" + ) + tr_loss = tr_loss + tr_loss_step + + self.current_flos += float(self.floating_point_ops(inputs)) + + if do_sync_step: + # Since we perform prefetching, we need to manually set sync_gradients to True + self.accelerator.gradient_state._set_sync_gradients(True) + + # Gradient clipping + if args.max_grad_norm is not None and args.max_grad_norm > 0: + if is_sagemaker_mp_enabled() and args.fp16: + _grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm) + elif self.use_apex: + from apex import amp + + # Revert to normal clipping otherwise, handling Apex or full precision + _grad_norm = nn.utils.clip_grad_norm_( + amp.master_params(self.optimizer), + args.max_grad_norm, + ) + else: + grad_norm_context = contextlib.nullcontext + if self.is_tp_enabled: + from torch.distributed._tensor.experimental import implicit_replication + + grad_norm_context = implicit_replication + with grad_norm_context(): + _grad_norm = self.accelerator.clip_grad_norm_( + model.parameters(), + args.max_grad_norm, + ) + + if ( + is_accelerate_available() + and self.accelerator.distributed_type == DistributedType.DEEPSPEED + ): + grad_norm = model.get_global_grad_norm() + # In some cases the grad norm may not return a float + if hasattr(grad_norm, "item"): + grad_norm = grad_norm.item() + else: + grad_norm = _grad_norm + + self.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control) + + context = contextlib.nullcontext + if self.is_tp_enabled: + from torch.distributed._tensor.experimental import implicit_replication + + context = implicit_replication + + with context(): + self.optimizer.step() + + self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control) + + # get leaning rate before update + learning_rate = self._get_learning_rate() + + if not self.accelerator.optimizer_step_was_skipped: + # Delay optimizer scheduling until metrics are generated + if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + self.lr_scheduler.step() + + model.zero_grad() + self.state.global_step += 1 + self.state.epoch = epoch + (step + 1) / steps_in_epoch + self.control = self.callback_handler.on_step_end(args, self.state, self.control) + self._maybe_log_save_evaluate( + tr_loss, + grad_norm, + model, + trial, + epoch, + ignore_keys_for_eval, + start_time, + learning_rate=learning_rate, + ) + else: + self.control = self.callback_handler.on_substep_end(args, self.state, self.control) + + # PyTorch/XLA relies on the data loader to insert the mark_step for + # each step. Since we are breaking the loop early, we need to manually + # insert the mark_step here. + if self.control.should_epoch_stop or self.control.should_training_stop: + if is_torch_xla_available(): + xm.mark_step() + break + # We also need to break out of the nested loop + if self.control.should_epoch_stop or self.control.should_training_stop: + if is_torch_xla_available(): + xm.mark_step() + break + if step < 0: + logger.warning( + "There seems not to be a single sample in your epoch_iterator, stopping training at step" + f" {self.state.global_step}! This is expected if you're using an IterableDataset and set" + f" num_steps ({max_steps}) higher than the number of available samples." + ) + self.control.should_training_stop = True + + self.control = self.callback_handler.on_epoch_end(args, self.state, self.control) + self._maybe_log_save_evaluate( + tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, learning_rate=learning_rate + ) + + if DebugOption.TPU_METRICS_DEBUG in self.args.debug: + if is_torch_xla_available(): + # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) + xm.master_print(met.metrics_report()) + else: + logger.warning( + "You enabled PyTorch/XLA debug metrics but you don't have a TPU " + "configured. Check your training configuration if this is unexpected." + ) + if self.control.should_training_stop: + break + + if args.past_index and hasattr(self, "_past"): + # Clean the state at the end of training + delattr(self, "_past") + + logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n") + if args.load_best_model_at_end and self.state.best_model_checkpoint is not None: + self._load_best_model() + + # add remaining tr_loss + self._total_loss_scalar += tr_loss.item() + effective_global_step = max(self.state.global_step, 0.001) # Avoid ZeroDivisionError + train_loss = self._total_loss_scalar / effective_global_step + + metrics = speed_metrics( + "train", + start_time, + num_samples=num_train_samples, + num_steps=self.state.max_steps, + num_tokens=num_train_tokens, + ) + self.store_flos() + metrics["total_flos"] = self.state.total_flos + metrics["train_loss"] = train_loss + + self.is_in_train = False + + self._memory_tracker.stop_and_update_metrics(metrics) + + self.log(metrics) + + run_dir = self._get_output_dir(trial) + checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir) + + # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save. + if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1: + for checkpoint in checkpoints_sorted: + if not os.path.samefile(checkpoint, self.state.best_model_checkpoint): + logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit") + shutil.rmtree(checkpoint, ignore_errors=True) + + self.control = self.callback_handler.on_train_end(args, self.state, self.control) + + # Wait for the checkpoint to be uploaded. + self._finish_current_push() + + # After training we make sure to retrieve back the original forward pass method + # for the embedding layer by removing the forward post hook. + if self.neftune_noise_alpha is not None: + self._deactivate_neftune(self.model) + + return TrainOutput(self.state.global_step, train_loss, metrics) + + def _get_output_dir(self, trial): + if self.hp_search_backend is not None and trial is not None: + if self.hp_search_backend == HPSearchBackend.OPTUNA: + run_id = trial.number + elif self.hp_search_backend == HPSearchBackend.RAY: + import ray.train + + run_id = ray.train.get_context().get_trial_id() + elif self.hp_search_backend == HPSearchBackend.SIGOPT: + run_id = trial.id + elif self.hp_search_backend == HPSearchBackend.WANDB: + import wandb + + run_id = wandb.run.id + run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}" + run_dir = os.path.join(self.args.output_dir, run_name) + else: + run_dir = self.args.output_dir + return run_dir + + def _load_from_checkpoint(self, resume_from_checkpoint, model=None): + if model is None: + model = self.model + + config_file = os.path.join(resume_from_checkpoint, CONFIG_NAME) + adapter_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_WEIGHTS_NAME) + adapter_safe_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME) + weights_file = os.path.join(resume_from_checkpoint, WEIGHTS_NAME) + weights_index_file = os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME) + safe_weights_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_NAME) + safe_weights_index_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_INDEX_NAME) + is_fsdp_ckpt = os.path.isdir(resume_from_checkpoint) and ( + # this checks the FSDP state dict when `SHARDED_STATE_DICT` is used + any( + FSDP_MODEL_NAME in folder_name + for folder_name in os.listdir(resume_from_checkpoint) + if os.path.isdir(os.path.join(resume_from_checkpoint, folder_name)) + ) + # this checks the FSDP state dict when `FULL_STATE_DICT` is used + or os.path.isfile(os.path.join(resume_from_checkpoint, f"{FSDP_MODEL_NAME}.bin")) + ) + # if multiple adapters exist, they get saved in sub directories + adapter_subdirs = ( + [ + folder_name + for folder_name in os.listdir(resume_from_checkpoint) + if os.path.isdir(os.path.join(resume_from_checkpoint, folder_name)) + and ( + os.path.isfile(os.path.join(resume_from_checkpoint, folder_name, ADAPTER_WEIGHTS_NAME)) + or os.path.isfile(os.path.join(resume_from_checkpoint, folder_name, ADAPTER_SAFE_WEIGHTS_NAME)) + ) + ] + if os.path.isdir(resume_from_checkpoint) + else [] + ) + + if is_fsdp_ckpt and not self.is_fsdp_enabled: + raise ValueError(f"Checkpoint found at {resume_from_checkpoint} is only supported when using PyTorch FSDP") + + if not ( + any( + os.path.isfile(f) + for f in [ + weights_file, + safe_weights_file, + weights_index_file, + safe_weights_index_file, + adapter_weights_file, + adapter_safe_weights_file, + ] + ) + or is_fsdp_ckpt + or adapter_subdirs + ): + raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}") + + logger.info(f"Loading model from {resume_from_checkpoint}.") + + if os.path.isfile(config_file): + config = PretrainedConfig.from_json_file(config_file) + checkpoint_version = config.transformers_version + if checkpoint_version is not None and checkpoint_version != __version__: + logger.warning( + f"You are resuming training from a checkpoint trained with {checkpoint_version} of " + f"Transformers but your current version is {__version__}. This is not recommended and could " + "yield to errors or unwanted behaviors." + ) + + if os.path.isfile(weights_file) or os.path.isfile(safe_weights_file) or is_fsdp_ckpt: + # If the model is on the GPU, it still works! + if is_sagemaker_mp_enabled(): + if os.path.isfile(os.path.join(resume_from_checkpoint, "user_content.pt")): + # If the 'user_content.pt' file exists, load with the new smp api. + # Checkpoint must have been saved with the new smp api. + smp.resume_from_checkpoint( + path=resume_from_checkpoint, tag=WEIGHTS_NAME, partial=False, load_optimizer=False + ) + else: + # If the 'user_content.pt' file does NOT exist, load with the old smp api. + # Checkpoint must have been saved with the old smp api. + if hasattr(self.args, "fp16") and self.args.fp16 is True: + logger.warning( + "Enabling FP16 and loading from smp < 1.10 checkpoint together is not supported." + ) + check_torch_load_is_safe() + state_dict = torch.load(weights_file, map_location="cpu", weights_only=True) + # Required for smp to not auto-translate state_dict from hf to smp (is already smp). + state_dict["_smp_is_partial"] = False + load_result = model.load_state_dict(state_dict, strict=True) + # release memory + del state_dict + elif self.is_fsdp_enabled: + load_fsdp_model( + self.accelerator.state.fsdp_plugin, + self.accelerator, + model, + resume_from_checkpoint, + **_get_fsdp_ckpt_kwargs(), + ) + else: + # We load the model state dict on the CPU to avoid an OOM error. + if self.args.save_safetensors and os.path.isfile(safe_weights_file): + state_dict = safetensors.torch.load_file(safe_weights_file, device="cpu") + else: + check_torch_load_is_safe() + state_dict = torch.load(weights_file, map_location="cpu", weights_only=True) + + # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963 + # which takes *args instead of **kwargs + load_result = model.load_state_dict(state_dict, False) + # release memory + del state_dict + self._issue_warnings_after_load(load_result) + + # Load adapters following PR # 24096 + elif _is_peft_model(model): + # If train a model using PEFT & LoRA, assume that adapter have been saved properly. + # TODO: in the future support only specific min PEFT versions + if (hasattr(model, "active_adapter") or hasattr(model, "active_adapters")) and hasattr( + model, "load_adapter" + ): + if os.path.exists(resume_from_checkpoint): + # For BC for older PEFT versions + if hasattr(model, "active_adapters"): + active_adapters = model.active_adapters + if len(active_adapters) > 1: + logger.warning("Multiple active adapters detected will only consider the first adapter") + active_adapter = active_adapters[0] + else: + active_adapter = model.active_adapter + + if adapter_subdirs: + for subdir_name in adapter_subdirs: + peft_id = os.path.join(resume_from_checkpoint, subdir_name) + model.load_adapter(peft_id, subdir_name, is_trainable=(subdir_name == active_adapter)) + model.set_adapter(active_adapter) + else: + model.load_adapter(resume_from_checkpoint, active_adapter, is_trainable=True) + else: + logger.warning( + "The intermediate checkpoints of PEFT may not be saved correctly, " + f"consider using a custom callback to save {ADAPTER_WEIGHTS_NAME} in corresponding saving folders. " + "Check some examples here: https://github.com/huggingface/peft/issues/96" + ) + else: + logger.warning("Could not load adapter model, make sure to have `peft>=0.3.0` installed") + else: + # We load the sharded checkpoint + load_result = load_sharded_checkpoint( + model, resume_from_checkpoint, strict=is_sagemaker_mp_enabled(), prefer_safe=self.args.save_safetensors + ) + if not is_sagemaker_mp_enabled(): + self._issue_warnings_after_load(load_result) + + def _load_best_model(self): + logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).") + best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME) + best_safe_model_path = os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_NAME) + best_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_WEIGHTS_NAME) + best_safe_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME) + + model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model + if self.is_deepspeed_enabled: + deepspeed_load_checkpoint( + self.model_wrapped, + self.state.best_model_checkpoint, + load_module_strict=not _is_peft_model(self.model), + ) + elif self.is_fsdp_enabled: + load_result = load_fsdp_model( + self.accelerator.state.fsdp_plugin, + self.accelerator, + model, + self.state.best_model_checkpoint, + **_get_fsdp_ckpt_kwargs(), + ) + elif ( + os.path.exists(best_model_path) + or os.path.exists(best_safe_model_path) + or os.path.exists(best_adapter_model_path) + or os.path.exists(best_safe_adapter_model_path) + ): + has_been_loaded = True + if is_sagemaker_mp_enabled(): + if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")): + # If the 'user_content.pt' file exists, load with the new smp api. + # Checkpoint must have been saved with the new smp api. + smp.resume_from_checkpoint( + path=self.state.best_model_checkpoint, + tag=WEIGHTS_NAME, + partial=False, + load_optimizer=False, + ) + else: + # If the 'user_content.pt' file does NOT exist, load with the old smp api. + # Checkpoint must have been saved with the old smp api. + if self.args.save_safetensors and os.path.isfile(best_safe_model_path): + state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu") + else: + check_torch_load_is_safe() + state_dict = torch.load(best_model_path, map_location="cpu", weights_only=True) + + state_dict["_smp_is_partial"] = False + load_result = model.load_state_dict(state_dict, strict=True) + else: + if _is_peft_model(model): + # If train a model using PEFT & LoRA, assume that adapter have been saved properly. + # TODO: in the future support only specific min PEFT versions + if (hasattr(model, "active_adapter") or hasattr(model, "active_adapters")) and hasattr( + model, "load_adapter" + ): + # For BC for older PEFT versions + if hasattr(model, "active_adapters"): + active_adapter = model.active_adapters[0] + if len(model.active_adapters) > 1: + logger.warning("Detected multiple active adapters, will only consider the first one") + else: + active_adapter = model.active_adapter + + if os.path.exists(best_adapter_model_path) or os.path.exists(best_safe_adapter_model_path): + try: + model.load_adapter(self.state.best_model_checkpoint, active_adapter) + except RuntimeError as exc: + if model.peft_config[active_adapter].is_prompt_learning: + # for context: https://github.com/huggingface/peft/issues/2256 + msg = ( + "When using prompt learning PEFT methods such as " + f"{model.peft_config[active_adapter].peft_type.value}, setting " + "load_best_model_at_end=True can lead to errors, it is recommended " + "to set this to False and to load the model manually from the checkpoint " + "directory using PeftModel.from_pretrained(base_model, ) after training " + "has finished." + ) + raise RuntimeError(msg) from exc + else: + raise + # Load_adapter has no return value present, modify it when appropriate. + from torch.nn.modules.module import _IncompatibleKeys + + load_result = _IncompatibleKeys([], []) + else: + logger.warning( + "The intermediate checkpoints of PEFT may not be saved correctly, " + f"consider using a custom callback to save {ADAPTER_WEIGHTS_NAME} in corresponding saving folders. " + "Check some examples here: https://github.com/huggingface/peft/issues/96" + ) + has_been_loaded = False + else: + logger.warning("Could not load adapter model, make sure to have `peft>=0.3.0` installed") + has_been_loaded = False + else: + # We load the model state dict on the CPU to avoid an OOM error. + if self.args.save_safetensors and os.path.isfile(best_safe_model_path): + state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu") + else: + check_torch_load_is_safe() + state_dict = torch.load(best_model_path, map_location="cpu", weights_only=True) + + # If the model is on the GPU, it still works! + # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963 + # which takes *args instead of **kwargs + load_result = model.load_state_dict(state_dict, False) + if not is_sagemaker_mp_enabled() and has_been_loaded: + self._issue_warnings_after_load(load_result) + elif os.path.exists(os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_INDEX_NAME)) or os.path.exists( + os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME) + ): + load_result = load_sharded_checkpoint( + model, self.state.best_model_checkpoint, strict=is_sagemaker_mp_enabled() + ) + if not is_sagemaker_mp_enabled(): + self._issue_warnings_after_load(load_result) + else: + logger.warning( + f"Could not locate the best model at {best_model_path}, if you are running a distributed training " + "on multiple nodes, you should activate `--save_on_each_node`." + ) + + def _issue_warnings_after_load(self, load_result): + if len(load_result.missing_keys) != 0: + if self.model._keys_to_ignore_on_save is not None and set(load_result.missing_keys) == set( + self.model._keys_to_ignore_on_save + ): + self.model.tie_weights() + else: + logger.warning(f"There were missing keys in the checkpoint model loaded: {load_result.missing_keys}.") + if len(load_result.unexpected_keys) != 0: + logger.warning( + f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}." + ) + + def _evaluate(self, trial, ignore_keys_for_eval, skip_scheduler=False): + metrics = self.evaluate(ignore_keys=ignore_keys_for_eval) + self._report_to_hp_search(trial, self.state.global_step, metrics) + + # Run delayed LR scheduler now that metrics are populated + if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau) and not skip_scheduler: + metric_to_check = self.args.metric_for_best_model + if not metric_to_check.startswith("eval_"): + metric_to_check = f"eval_{metric_to_check}" + try: + self.lr_scheduler.step(metrics[metric_to_check]) + except KeyError as exc: + raise KeyError( + f"The `metric_for_best_model` training argument is set to '{metric_to_check}', " + f"which is not found in the evaluation metrics. " + f"The available evaluation metrics are: {list(metrics.keys())}. " + f"Please ensure that the `compute_metrics` function returns a dictionary that includes '{metric_to_check}' or " + f"consider changing the `metric_for_best_model` via the TrainingArguments." + ) from exc + return metrics + + def _maybe_log_save_evaluate( + self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, learning_rate=None + ): + if self.control.should_log and self.state.global_step > self._globalstep_last_logged: + if is_torch_xla_available(): + xm.mark_step() + + logs: dict[str, float] = {} + + # all_gather + mean() to get average loss over all processes + tr_loss_scalar = self._nested_gather(tr_loss).mean().item() + + # reset tr_loss to zero + tr_loss -= tr_loss + + logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4) + if grad_norm is not None: + logs["grad_norm"] = grad_norm.item() if isinstance(grad_norm, torch.Tensor) else grad_norm + if learning_rate is not None: + logs["learning_rate"] = learning_rate + else: + logs["learning_rate"] = self._get_learning_rate() + + self._total_loss_scalar += tr_loss_scalar + self._globalstep_last_logged = self.state.global_step + self.store_flos() + + self.log(logs, start_time) + + metrics = None + if self.control.should_evaluate: + metrics = self._evaluate(trial, ignore_keys_for_eval) + is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial) + + if self.args.save_strategy == SaveStrategy.BEST: + self.control.should_save = is_new_best_metric + + if self.control.should_save: + self._save_checkpoint(model, trial) + self.control = self.callback_handler.on_save(self.args, self.state, self.control) + + def _load_rng_state(self, checkpoint): + # Load RNG states from `checkpoint` + if checkpoint is None: + return + + if self.args.world_size > 1: + process_index = self.args.process_index + rng_file = os.path.join(checkpoint, f"rng_state_{process_index}.pth") + if not os.path.isfile(rng_file): + logger.info( + f"Didn't find an RNG file for process {process_index}, if you are resuming a training that " + "wasn't launched in a distributed fashion, reproducibility is not guaranteed." + ) + return + else: + rng_file = os.path.join(checkpoint, "rng_state.pth") + if not os.path.isfile(rng_file): + logger.info( + "Didn't find an RNG file, if you are resuming a training that was launched in a distributed " + "fashion, reproducibility is not guaranteed." + ) + return + + with safe_globals(): + checkpoint_rng_state = torch.load(rng_file) + random.setstate(checkpoint_rng_state["python"]) + np.random.set_state(checkpoint_rng_state["numpy"]) + torch.random.set_rng_state(checkpoint_rng_state["cpu"]) + if is_torch_xla_available(): + xm.set_rng_state(checkpoint_rng_state["xla"]) + + is_distributed = self.args.parallel_mode == ParallelMode.DISTRIBUTED + if torch.cuda.is_available(): + set_rng_state_for_device("CUDA", torch.cuda, checkpoint_rng_state, is_distributed) + if is_torch_npu_available(): + set_rng_state_for_device("NPU", torch.npu, checkpoint_rng_state, is_distributed) + if is_torch_hpu_available(): + set_rng_state_for_device("HPU", torch.hpu, checkpoint_rng_state, is_distributed) + if is_torch_mlu_available(): + set_rng_state_for_device("MLU", torch.mlu, checkpoint_rng_state, is_distributed) + if is_torch_musa_available(): + set_rng_state_for_device("MUSA", torch.musa, checkpoint_rng_state, is_distributed) + + def _determine_best_metric(self, metrics, trial): + """ + Determine if the model should be saved based on the evaluation metrics. + + Returns: + bool: True if a new best metric was found, else False + """ + is_new_best_metric = False + + if self.args.metric_for_best_model is not None: + metric_to_check = self.args.metric_for_best_model + + if not metric_to_check.startswith("eval_"): + metric_to_check = f"eval_{metric_to_check}" + + try: + metric_value = metrics[metric_to_check] + except KeyError as exc: + raise KeyError( + f"The `metric_for_best_model` training argument is set to '{metric_to_check}', which is not found in the evaluation metrics. " + f"The available evaluation metrics are: {list(metrics.keys())}. Consider changing the `metric_for_best_model` via the TrainingArguments." + ) from exc + + operator = np.greater if self.args.greater_is_better else np.less + + if self.state.best_metric is None: + self.state.best_metric = float("-inf") if self.args.greater_is_better else float("inf") + + if operator(metric_value, self.state.best_metric): + self.state.best_metric = metric_value + + if self.args.save_strategy in [SaveStrategy.STEPS, SaveStrategy.EPOCH]: + self.state.best_global_step = self.state.global_step + + is_new_best_metric = True + + return is_new_best_metric + + def _save_checkpoint(self, model, trial): + # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we + # want to save except FullyShardedDDP. + # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model" + + # Save model checkpoint + checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" + + if self.hp_search_backend is None and trial is None: + self.store_flos() + + run_dir = self._get_output_dir(trial=trial) + output_dir = os.path.join(run_dir, checkpoint_folder) + self.save_model(output_dir, _internal_call=True) + + if self.args.save_strategy in [SaveStrategy.STEPS, SaveStrategy.EPOCH] and self.state.best_global_step: + # Wait for everyone to get here so we are sure the model has been saved by process 0 + # before we check if the best_checkpoint_dir exists + if is_torch_xla_available(): + xm.rendezvous("load_best_model_at_end") + elif self.args.parallel_mode == ParallelMode.DISTRIBUTED: + dist.barrier() + elif is_sagemaker_mp_enabled(): + smp.barrier() + + best_checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.best_global_step}" + best_checkpoint_dir = os.path.join(run_dir, best_checkpoint_folder) + + if os.path.exists(best_checkpoint_dir): + self.state.best_model_checkpoint = best_checkpoint_dir + + if not self.args.save_only_model: + # Save optimizer and scheduler + self._save_optimizer_and_scheduler(output_dir) + self._save_scaler(output_dir) + # Save RNG state + self._save_rng_state(output_dir) + + # Save the Trainer state + if self.args.should_save: + # Update `ExportableState` callbacks and `TrainerControl` state to where we are currently + for cb in [ + cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState) + ]: + cb_name = cb.__class__.__name__ + cb_state = cb.state() + if isinstance(self.state.stateful_callbacks[cb_name], list): + self.state.stateful_callbacks[cb_name].append(cb_state) + else: + self.state.stateful_callbacks[cb_name] = cb_state + self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME)) + + if self.args.push_to_hub: + self._push_from_checkpoint(output_dir) + + # Maybe delete some older checkpoints. + if self.args.should_save: + # we use mtime as default, filesystems without mtime support will be detected in `_sorted_checkpoints` + self._rotate_checkpoints(use_mtime=True, output_dir=run_dir) + + def _save_rng_state(self, output_dir): + # Save RNG state in non-distributed training + rng_states = { + "python": random.getstate(), + "numpy": np.random.get_state(), + "cpu": torch.random.get_rng_state(), + } + if torch.cuda.is_available(): + if self.args.parallel_mode == ParallelMode.DISTRIBUTED: + # In non distributed, we save the global CUDA RNG state (will take care of DataParallel) + rng_states["cuda"] = torch.cuda.random.get_rng_state_all() + else: + rng_states["cuda"] = torch.cuda.random.get_rng_state() + + if is_torch_xla_available(): + rng_states["xla"] = xm.get_rng_state() + + if is_torch_npu_available(): + if self.args.parallel_mode == ParallelMode.DISTRIBUTED: + rng_states["npu"] = torch.npu.random.get_rng_state_all() + else: + rng_states["npu"] = torch.npu.random.get_rng_state() + + if is_torch_hpu_available(): + if self.args.parallel_mode == ParallelMode.DISTRIBUTED: + rng_states["hpu"] = torch.hpu.random.get_rng_state_all() + else: + rng_states["hpu"] = torch.hpu.random.get_rng_state() + + if is_torch_mlu_available(): + if self.args.parallel_mode == ParallelMode.DISTRIBUTED: + rng_states["mlu"] = torch.mlu.random.get_rng_state_all() + else: + rng_states["mlu"] = torch.mlu.random.get_rng_state() + + if is_torch_musa_available(): + if self.args.parallel_mode == ParallelMode.DISTRIBUTED: + rng_states["musa"] = torch.musa.get_rng_state_all() + else: + rng_states["musa"] = torch.musa.get_rng_state() + + # A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may + # not yet exist. + os.makedirs(output_dir, exist_ok=True) + + if self.args.world_size <= 1: + torch.save(rng_states, os.path.join(output_dir, "rng_state.pth")) + else: + torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth")) + + def _save_optimizer_and_scheduler(self, output_dir): + if is_torch_xla_available(): + xm.rendezvous("saving_optimizer_states") + if self.is_fsdp_xla_v1_enabled: + optm = { + "optimizer": self.optimizer.state_dict(), + "shard_metadata": self.model.get_shard_metadata(), + } + xm.save( + optm, + os.path.join( + output_dir, f"rank{self.args.process_index}-of-{self.args.world_size}-{OPTIMIZER_NAME}" + ), + master_only=False, + ) + else: + xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) + with warnings.catch_warnings(record=True) as caught_warnings: + xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) + reissue_pt_warnings(caught_warnings) + elif is_sagemaker_mp_enabled(): + opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False) + smp.barrier() + if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state: + smp.save( + opt_state_dict, + os.path.join(output_dir, OPTIMIZER_NAME), + partial=True, + v3=smp.state.cfg.shard_optimizer_state, + ) + elif self.is_deepspeed_enabled: + # under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed + # config `stage3_gather_16bit_weights_on_model_save` is True + accept_exclude_frozen_parameters = "exclude_frozen_parameters" in set( + inspect.signature(self.model_wrapped.save_checkpoint).parameters.keys() + ) + if accept_exclude_frozen_parameters and _is_peft_model(self.model): + self.model_wrapped.save_checkpoint(output_dir, exclude_frozen_parameters=True) + else: + self.model_wrapped.save_checkpoint(output_dir) + elif self.is_fsdp_enabled: + # save fsdp specific ckpt for resuming from ckpt + save_fsdp_model( + self.accelerator.state.fsdp_plugin, self.accelerator, self.model, output_dir, **_get_fsdp_ckpt_kwargs() + ) + save_fsdp_optimizer( + self.accelerator.state.fsdp_plugin, self.accelerator, self.optimizer, self.model, output_dir + ) + elif self.args.should_save: + # deepspeed.save_checkpoint above saves model/optim/sched + torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) + + # Save SCHEDULER & SCALER + is_deepspeed_custom_scheduler = self.is_deepspeed_enabled and not isinstance( + self.lr_scheduler, DeepSpeedSchedulerWrapper + ) + if ( + self.args.should_save + and (not self.is_deepspeed_enabled or is_deepspeed_custom_scheduler) + and not is_torch_xla_available() + ): + with warnings.catch_warnings(record=True) as caught_warnings: + torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) + reissue_pt_warnings(caught_warnings) + + def _load_optimizer_and_scheduler(self, checkpoint): + """If optimizer and scheduler states exist, load them.""" + if checkpoint is None: + return + + if self.is_deepspeed_enabled: + # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init + if not isinstance(self.lr_scheduler, DeepSpeedSchedulerWrapper): + with warnings.catch_warnings(record=True) as caught_warnings: + check_torch_load_is_safe() + self.lr_scheduler.load_state_dict( + torch.load(os.path.join(checkpoint, SCHEDULER_NAME), weights_only=True) + ) + reissue_pt_warnings(caught_warnings) + return + + checkpoint_file_exists = ( + glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + "_*") + if is_sagemaker_mp_enabled() + else ( + os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME)) + or os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME_BIN)) + or ( + os.path.isdir(checkpoint) + and any( + OPTIMIZER_NAME_BIN.split(".")[0] in folder_name + for folder_name in os.listdir(checkpoint) + if os.path.isdir(os.path.join(checkpoint, folder_name)) + ) + ) + ) + ) + checkpoint_file_exists = ( + glob.glob(os.path.join(checkpoint, f"rank*-of-{self.args.world_size}-{OPTIMIZER_NAME}")) + if self.is_fsdp_xla_v1_enabled + else checkpoint_file_exists + ) + if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)): + # Load in optimizer and scheduler states + if is_torch_xla_available(): + # On TPU we have to take some extra precautions to properly load the states on the right device. + if self.is_fsdp_xla_v1_enabled: + check_torch_load_is_safe() + optimizer_state = torch.load( + os.path.join( + checkpoint, f"rank{self.args.process_index}-of-{self.args.world_size}-{OPTIMIZER_NAME}" + ), + map_location="cpu", + weights_only=True, + ) + # We only need `optimizer` when resuming from checkpoint + optimizer_state = optimizer_state["optimizer"] + else: + check_torch_load_is_safe() + optimizer_state = torch.load( + os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu", weights_only=True + ) + with warnings.catch_warnings(record=True) as caught_warnings: + check_torch_load_is_safe() + lr_scheduler_state = torch.load( + os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu", weights_only=True + ) + reissue_pt_warnings(caught_warnings) + + xm.send_cpu_data_to_device(optimizer_state, self.args.device) + xm.send_cpu_data_to_device(lr_scheduler_state, self.args.device) + + self.optimizer.load_state_dict(optimizer_state) + self.lr_scheduler.load_state_dict(lr_scheduler_state) + else: + if is_sagemaker_mp_enabled(): + if os.path.isfile(os.path.join(checkpoint, "user_content.pt")): + # Optimizer checkpoint was saved with smp >= 1.10 + def opt_load_hook(mod, opt): + opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True)) + + else: + # Optimizer checkpoint was saved with smp < 1.10 + def opt_load_hook(mod, opt): + if IS_SAGEMAKER_MP_POST_1_10: + opt.load_state_dict( + smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True, back_compat=True) + ) + else: + opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True)) + + self.model_wrapped.register_post_step_hook(opt_load_hook) + else: + # We use the CPU when training on one GPU to avoid OOM for GPU RAM when training big models. + # In distributed training however, we load directly on each GPU and risk the GPU OOM as it's more + # likely to get OOM on CPU (since we load num_gpu times the optimizer state + map_location = self.args.device if self.args.world_size > 1 else "cpu" + if self.is_fsdp_enabled: + load_fsdp_optimizer( + self.accelerator.state.fsdp_plugin, + self.accelerator, + self.optimizer, + self.model, + checkpoint, + **_get_fsdp_ckpt_kwargs(), + ) + else: + check_torch_load_is_safe() + self.optimizer.load_state_dict( + torch.load( + os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location, weights_only=True + ) + ) + with warnings.catch_warnings(record=True) as caught_warnings: + check_torch_load_is_safe() + self.lr_scheduler.load_state_dict( + torch.load(os.path.join(checkpoint, SCHEDULER_NAME), weights_only=True) + ) + reissue_pt_warnings(caught_warnings) + + def _save_scaler(self, output_dir): + # See if there is a scaler attribute + try: + scaler = self.accelerator.scaler + except AttributeError: + return + if scaler is None: + return + if is_torch_xla_available(): + xm.rendezvous("saving_scaler_state") + with warnings.catch_warnings(record=True) as caught_warnings: + xm.save(self.accelerator.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME)) + reissue_pt_warnings(caught_warnings) + + # Save SCALER + if self.args.should_save and not is_torch_xla_available(): + with warnings.catch_warnings(record=True) as caught_warnings: + torch.save(self.accelerator.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME)) + reissue_pt_warnings(caught_warnings) + + def _load_scaler(self, checkpoint): + """If scaler state exists, load it.""" + if checkpoint is None: + return + + checkpoint_file_exists = os.path.isfile(os.path.join(checkpoint, SCALER_NAME)) + + if checkpoint_file_exists: + # On TPU we have to take some extra precautions to properly load the states on the right device. + # Load in scaler states + if is_torch_xla_available(): + with warnings.catch_warnings(record=True) as caught_warnings: + check_torch_load_is_safe() + scaler_state = torch.load( + os.path.join(checkpoint, SCALER_NAME), map_location="cpu", weights_only=True + ) + reissue_pt_warnings(caught_warnings) + xm.send_cpu_data_to_device(scaler_state, self.args.device) + self.accelerator.scaler.load_state_dict(scaler_state) + else: + with warnings.catch_warnings(record=True) as caught_warnings: + check_torch_load_is_safe() + self.accelerator.scaler.load_state_dict( + torch.load(os.path.join(checkpoint, SCALER_NAME), weights_only=True) + ) + reissue_pt_warnings(caught_warnings) + + def _load_callback_state(self): + """If callback states exist and were passed in, restore their states if enabled""" + if not self.args.restore_callback_states_from_checkpoint: + return + # Callback states are stored in stateful_callbacks + not_found = [] + new_callbacks = [] + original_callbacks = self.callback_handler.callbacks + [self.control] + for stored_callback, data in self.state.stateful_callbacks.items(): + if not isinstance(data, list): + data = [data] + if any(callback.__class__.__name__ == stored_callback for callback in original_callbacks): + # We can load/restore from multiple callbacks of the same type. + duplicates = [ + callback for callback in original_callbacks if callback.__class__.__name__ == stored_callback + ] + for callback, callback_data in zip(duplicates, data): + args = callback_data.get("args", {}) + attributes = callback_data.get("attributes", {}) + new_callback = type(callback)(**args) + for attribute, value in attributes.items(): + setattr(new_callback, attribute, value) + if isinstance(callback, TrainerControl): + # Specifically for restoring the `control` state + self.control = new_callback + else: + new_callbacks.append(new_callback) + # We remove the existing callback and add it to the list of new callbacks + self.callback_handler.remove_callback(type(new_callback)) + logger.info("Continuing training from checkpoint, restoring any callbacks that were passed in") + else: + not_found.append(stored_callback) + if len(not_found) > 0: + logger.warning( + f"Checkpoint included callbacks not included in current configuration. Ignoring. ({', '.join(not_found)})" + ) + for callback in new_callbacks: + self.callback_handler.add_callback(callback) + + def hyperparameter_search( + self, + hp_space: Optional[Callable[["optuna.Trial"], dict[str, float]]] = None, + compute_objective: Optional[Callable[[dict[str, float]], float]] = None, + n_trials: int = 20, + direction: Union[str, list[str]] = "minimize", + backend: Optional[Union["str", HPSearchBackend]] = None, + hp_name: Optional[Callable[["optuna.Trial"], str]] = None, + **kwargs, + ) -> Union[BestRun, list[BestRun]]: + """ + Launch an hyperparameter search using `optuna` or `Ray Tune` or `SigOpt`. The optimized quantity is determined + by `compute_objective`, which defaults to a function returning the evaluation loss when no metric is provided, + the sum of all metrics otherwise. + + + + To use this method, you need to have provided a `model_init` when initializing your [`Trainer`]: we need to + reinitialize the model at each new run. This is incompatible with the `optimizers` argument, so you need to + subclass [`Trainer`] and override the method [`~Trainer.create_optimizer_and_scheduler`] for custom + optimizer/scheduler. + + + + Args: + hp_space (`Callable[["optuna.Trial"], dict[str, float]]`, *optional*): + A function that defines the hyperparameter search space. Will default to + [`~trainer_utils.default_hp_space_optuna`] or [`~trainer_utils.default_hp_space_ray`] or + [`~trainer_utils.default_hp_space_sigopt`] depending on your backend. + compute_objective (`Callable[[dict[str, float]], float]`, *optional*): + A function computing the objective to minimize or maximize from the metrics returned by the `evaluate` + method. Will default to [`~trainer_utils.default_compute_objective`]. + n_trials (`int`, *optional*, defaults to 100): + The number of trial runs to test. + direction (`str` or `list[str]`, *optional*, defaults to `"minimize"`): + If it's single objective optimization, direction is `str`, can be `"minimize"` or `"maximize"`, you + should pick `"minimize"` when optimizing the validation loss, `"maximize"` when optimizing one or + several metrics. If it's multi objectives optimization, direction is `list[str]`, can be List of + `"minimize"` and `"maximize"`, you should pick `"minimize"` when optimizing the validation loss, + `"maximize"` when optimizing one or several metrics. + backend (`str` or [`~training_utils.HPSearchBackend`], *optional*): + The backend to use for hyperparameter search. Will default to optuna or Ray Tune or SigOpt, depending + on which one is installed. If all are installed, will default to optuna. + hp_name (`Callable[["optuna.Trial"], str]]`, *optional*): + A function that defines the trial/run name. Will default to None. + kwargs (`dict[str, Any]`, *optional*): + Additional keyword arguments for each backend: + + - `optuna`: parameters from + [optuna.study.create_study](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.create_study.html) + and also the parameters `timeout`, `n_jobs` and `gc_after_trial` from + [optuna.study.Study.optimize](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.Study.html#optuna.study.Study.optimize) + - `ray`: parameters from [tune.run](https://docs.ray.io/en/latest/tune/api_docs/execution.html#tune-run). + If `resources_per_trial` is not set in the `kwargs`, it defaults to 1 CPU core and 1 GPU (if available). + If `progress_reporter` is not set in the `kwargs`, + [ray.tune.CLIReporter](https://docs.ray.io/en/latest/tune/api/doc/ray.tune.CLIReporter.html) is used. + - `sigopt`: the parameter `proxies` from + [sigopt.Connection.set_proxies](https://docs.sigopt.com/support/faq#how-do-i-use-sigopt-with-a-proxy). + + Returns: + [`trainer_utils.BestRun` or `list[trainer_utils.BestRun]`]: All the information about the best run or best + runs for multi-objective optimization. Experiment summary can be found in `run_summary` attribute for Ray + backend. + """ + if backend is None: + backend = default_hp_search_backend() + backend = HPSearchBackend(backend) + backend_obj = ALL_HYPERPARAMETER_SEARCH_BACKENDS[backend]() + backend_obj.ensure_available() + self.hp_search_backend = backend + if self.model_init is None: + raise RuntimeError( + "To use hyperparameter search, you need to pass your model through a model_init function." + ) + + self.hp_space = backend_obj.default_hp_space if hp_space is None else hp_space + self.hp_name = hp_name + self.compute_objective = default_compute_objective if compute_objective is None else compute_objective + + best_run = backend_obj.run(self, n_trials, direction, **kwargs) + + self.hp_search_backend = None + return best_run + + def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None: + """ + Log `logs` on the various objects watching training. + + Subclass and override this method to inject custom behavior. + + Args: + logs (`dict[str, float]`): + The values to log. + start_time (`Optional[float]`): + The start of training. + """ + if self.state.epoch is not None: + logs["epoch"] = self.state.epoch + if self.args.include_num_input_tokens_seen != "no": + logs["num_input_tokens_seen"] = self.state.num_input_tokens_seen + if start_time is not None: + logs.update(speed_metrics("train", start_time, num_tokens=self.state.num_input_tokens_seen)) + + output = {**logs, **{"step": self.state.global_step}} + self.state.log_history.append(output) + self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs) + + def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]: + """ + Prepares one `data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors. + """ + if isinstance(data, Mapping): + return type(data)({k: self._prepare_input(v) for k, v in data.items()}) + elif isinstance(data, (tuple, list)): + return type(data)(self._prepare_input(v) for v in data) + elif isinstance(data, torch.Tensor): + kwargs = {"device": self.args.device} + if self.is_deepspeed_enabled and (torch.is_floating_point(data) or torch.is_complex(data)): + # NLP models inputs are int/uint and those get adjusted to the right dtype of the + # embedding. Other models such as wav2vec2's inputs are already float and thus + # may need special handling to match the dtypes of the model + kwargs.update({"dtype": self.accelerator.state.deepspeed_plugin.hf_ds_config.dtype()}) + return data.to(**kwargs) + return data + + def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]: + """ + Prepare `inputs` before feeding them to the model, converting them to tensors if they are not already and + handling potential state. + """ + inputs = self._prepare_input(inputs) + if len(inputs) == 0: + raise ValueError( + "The batch received was empty, your model won't be able to train on it. Double-check that your " + f"training dataset contains keys expected by the model: {','.join(self._signature_columns)}." + ) + if self.args.past_index >= 0 and self._past is not None: + inputs["mems"] = self._past + + return inputs + + def _is_attention_mask_causal(self, attention_mask): + """ + Check if an attention mask is causal (compatible with causal attention). + Context parallelism only supports causal attention patterns. This function + checks if the provided attention mask is compatible. + + Args: + attention_mask (torch.Tensor): The attention mask to check + + Returns: + bool: True if the mask is causal or compatible with causal attention + """ + if attention_mask is None: + return True # No mask is considered causal (model uses default causal masking) + + # Handle different mask dimensions + if attention_mask.dim() == 2: + # (batch_size, seq_len) - standard padding mask, compatible with causal attention + return True + elif attention_mask.dim() in [3, 4]: + # (batch_size, seq_len, seq_len) or (batch_size, num_heads, seq_len, seq_len) + # Check if it's lower triangular (causal) + seq_len = attention_mask.shape[-1] + if seq_len <= 1: + return True # Single token or empty is always causal + + # Take first batch and head (if 4D) for checking pattern + if attention_mask.dim() == 4: + mask = attention_mask[0, 0] # First batch, first head + else: + mask = attention_mask[0] # First batch + + # Check if upper triangular part is masked (should be 0 or very negative for causal) + upper_triangular = torch.triu(mask, diagonal=1) + + # For causal masks, upper triangular should be 0 or very negative (like -inf) + # Use a reasonable threshold to handle float precision issues + is_causal = torch.all(upper_triangular <= 1e-6) or torch.all(upper_triangular < -1e4) + return is_causal.item() if isinstance(is_causal, torch.Tensor) else is_causal + + # For unknown dimensions, be conservative and reject + return False + + def _prepare_context_parallel_inputs(self, model, inputs: dict[str, Union[torch.Tensor, Any]]): + """ + Prepare inputs for context parallelism by setting up buffers and validation. + + Args: + model: The model being trained + inputs: Input tensors to prepare + + Returns: + tuple: (context_manager, prepared_inputs) where context_manager is either + the context parallelism wrapper or a no-op context + """ + if ( + getattr(self.accelerator, "parallelism_config", None) is not None + and self.accelerator.parallelism_config.cp_enabled + ): + if hasattr(model, "config"): + if model.config._attn_implementation != "sdpa": + raise ValueError( + f"Context parallelism is supported only with SDPA attention, you are using {model.config._attn_implementation}." + ) + + if "position_ids" not in inputs: + logger.warning_once("Position IDs not found in the inputs, generating manually") + inputs["position_ids"] = torch.arange( + inputs["input_ids"].size(1), device=inputs["input_ids"].device + ).expand(inputs["input_ids"].size(0), -1) + if "shift_labels" not in inputs: + logger.warning_once("Shift labels not found in the inputs, shifting manually") + if "labels" in inputs: + _ignore_index = -100 + labels = nn.functional.pad(inputs["labels"], (0, 1), value=_ignore_index) + inputs["shift_labels"] = labels[:, 1:].contiguous() + + buffers = [] + buffer_seq_dims = [] + + if "input_ids" in inputs: + buffers.append(inputs["input_ids"]) + buffer_seq_dims.append(1) # Sequence dimension + if "labels" in inputs: + buffers.append(inputs["labels"]) + buffer_seq_dims.append(1) + if "shift_labels" in inputs: + buffers.append(inputs["shift_labels"]) + buffer_seq_dims.append(1) + # Add attention_mask to buffers for context parallel splitting (only if causal) + if "attention_mask" in inputs: + # Only validate causal mask once for performance + if not getattr(self, "_attn_mask_causal_checked", False): + # Context parallel currently doesn't support other masks than causal + # Accelerate applies hooks to replace mask with is_causal arg in SDPA + # Check if the mask is really causal and if not throw an error + attention_mask = inputs["attention_mask"] + if not self._is_attention_mask_causal(attention_mask): + raise ValueError( + "Context parallelism only supports causal attention masks. " + "The provided attention_mask is not causal. " + "Please ensure your data uses causal masking (lower triangular) " + "or remove the attention_mask to use the model's default causal masking." + ) + self._attn_mask_causal_checked = True + if self._attn_mask_causal_checked: + # Add to buffers only after validation (or if validation already passed) + attention_mask = inputs["attention_mask"] + if attention_mask.dim() == 2: + buffers.append(attention_mask) + buffer_seq_dims.append(1) + else: + # Other dimensionality; keep as-is without sharding to avoid incorrect splits + pass + # Include position_ids in context parallelism splitting + if "position_ids" in inputs and inputs["position_ids"] is not None: + buffers.append(inputs["position_ids"]) + buffer_seq_dims.append(1) + + return partial( + self.accelerator.maybe_context_parallel, + buffers=buffers, + buffer_seq_dims=buffer_seq_dims, + no_restore_buffers=set(buffers), + ), inputs + + return contextlib.nullcontext, inputs + + def compute_loss_context_manager(self): + """ + A helper wrapper to group together context managers. + """ + ctx_stack = contextlib.ExitStack() + + autocast_ctx = self.autocast_smart_context_manager() + if not isinstance(autocast_ctx, contextlib.nullcontext): + ctx_stack.enter_context(autocast_ctx) + + return ctx_stack + + def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True): + """ + A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired + arguments, depending on the situation. + """ + if self.use_cpu_amp: + ctx_manager = torch.autocast(device_type="cpu", cache_enabled=cache_enabled, dtype=self.amp_dtype) + else: + ctx_manager = contextlib.nullcontext() + + return ctx_manager + + def training_step( + self, + model: nn.Module, + inputs: dict[str, Union[torch.Tensor, Any]], + num_items_in_batch: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Perform a training step on a batch of inputs. + + Subclass and override to inject custom behavior. + + Args: + model (`nn.Module`): + The model to train. + inputs (`dict[str, Union[torch.Tensor, Any]]`): + The inputs and targets of the model. + + The dictionary will be unpacked before being fed to the model. Most models expect the targets under the + argument `labels`. Check your model's documentation for all accepted arguments. + + Return: + `torch.Tensor`: The tensor with training loss on this batch. + """ + # Prepare buffers for context parallelism + + cp_context, inputs = self._prepare_context_parallel_inputs(model, inputs) + + # Context manager is no-op if CP isn't enabled + with cp_context(): + model.train() + if hasattr(self.optimizer, "train") and callable(self.optimizer.train): + self.optimizer.train() + + inputs = self._prepare_inputs(inputs) + if is_sagemaker_mp_enabled(): + loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps) + return loss_mb.reduce_mean().detach().to(self.args.device) + + with self.compute_loss_context_manager(): + loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch) + + del inputs + if ( + self.args.torch_empty_cache_steps is not None + and self.state.global_step % self.args.torch_empty_cache_steps == 0 + ): + if is_torch_xpu_available(): + torch.xpu.empty_cache() + elif is_torch_mlu_available(): + torch.mlu.empty_cache() + elif is_torch_musa_available(): + torch.musa.empty_cache() + elif is_torch_npu_available(): + torch.npu.empty_cache() + elif is_torch_mps_available(): + torch.mps.empty_cache() + elif is_torch_hpu_available(): + logger.warning( + "`torch_empty_cache_steps` is set but HPU device/backend does not support empty_cache()." + ) + else: + torch.cuda.empty_cache() + + kwargs = {} + + # For LOMO optimizers you need to explicitly use the learning rate + if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]: + kwargs["learning_rate"] = self._get_learning_rate() + + if self.args.n_gpu > 1: + loss = loss.mean() # mean() to average on multi-gpu parallel training + + if self.use_apex: + from apex import amp + + with amp.scale_loss(loss, self.optimizer) as scaled_loss: + scaled_loss.backward() + else: + # Finally we need to normalize the loss for reporting if GA loss bug is not fixed during compute loss + if ( + not self.model_accepts_loss_kwargs or num_items_in_batch is None + ) and self.compute_loss_func is None: + # If the model does not accept loss kwargs, we need to normalize the loss by the number of gradient accumulation steps + loss = loss / self.current_gradient_accumulation_steps + + # Turning off loss scaling w.r.t. gradient accumulation when DeepSpeed is enabled + # https://github.com/huggingface/transformers/pull/35808 + if self.accelerator.distributed_type == DistributedType.DEEPSPEED: + kwargs["scale_wrt_gas"] = False + + self.accelerator.backward(loss, **kwargs) + + return loss.detach() + + def compute_loss( + self, + model: nn.Module, + inputs: dict[str, Union[torch.Tensor, Any]], + return_outputs: bool = False, + num_items_in_batch: Optional[torch.Tensor] = None, + ): + """ + How the loss is computed by Trainer. By default, all models return the loss in the first element. + + Args: + model (`nn.Module`): + The model to compute the loss for. + inputs (`dict[str, Union[torch.Tensor, Any]]`): + The input data for the model. + return_outputs (`bool`, *optional*, defaults to `False`): + Whether to return the model outputs along with the loss. + num_items_in_batch (Optional[torch.Tensor], *optional*): + The number of items in the batch. If num_items_in_batch is not passed, + + Returns: + The loss of the model along with its output if return_outputs was set to True + + Subclass and override for custom behavior. If you are not using `num_items_in_batch` when computing your loss, + make sure to overwrite `self.model_accepts_loss_kwargs` to `False`. Otherwise, the loss calculating might be slightly inaccurate when performing gradient accumulation. + """ + if (self.label_smoother is not None or self.compute_loss_func is not None) and "labels" in inputs: + labels = inputs.pop("labels") + else: + labels = None + if self.model_accepts_loss_kwargs: + kwargs = {} + if num_items_in_batch is not None: + kwargs["num_items_in_batch"] = num_items_in_batch + inputs = {**inputs, **kwargs} + outputs = model(**inputs) + # Save past state if it exists + # TODO: this needs to be fixed and made cleaner later. + if self.args.past_index >= 0: + self._past = outputs[self.args.past_index] + + # User-defined compute_loss function + if self.compute_loss_func is not None: + if labels is None: + logger.warning( + "Trainer: `compute_loss_func` is defined but `labels=None`. " + "Your custom loss function will still be called with labels=None. " + ) + loss = self.compute_loss_func( + outputs, + labels, + num_items_in_batch=num_items_in_batch, + ) + # Default HF loss handling (label smoothing) if no custom loss function + elif labels is not None: + unwrapped_model = self.accelerator.unwrap_model(model) + model_name = ( + unwrapped_model.base_model.model._get_name() + if _is_peft_model(unwrapped_model) + else unwrapped_model._get_name() + ) + if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values(): + loss = self.label_smoother(outputs, labels, shift_labels=True) + else: + loss = self.label_smoother(outputs, labels) + else: + if isinstance(outputs, dict) and "loss" not in outputs: + raise ValueError( + "The model did not return a loss from the inputs, only the following keys: " + f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}." + ) + # We don't use .loss here since the model may return tuples instead of ModelOutput. + loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] + + if ( + self.args.average_tokens_across_devices + and (self.model_accepts_loss_kwargs or self.compute_loss_func) + and num_items_in_batch is not None + ): + loss *= self.accelerator.num_processes if self.args.n_gpu <= 1 else self.args.n_gpu + + return (loss, outputs) if return_outputs else loss + + def is_local_process_zero(self) -> bool: + """ + Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on several + machines) main process. + """ + return self.args.local_process_index == 0 + + def is_world_process_zero(self) -> bool: + """ + Whether or not this process is the global main process (when training in a distributed fashion on several + machines, this is only going to be `True` for one process). + """ + # Special case for SageMaker ModelParallel since there process_index is dp_process_index, not the global + # process index. + if is_sagemaker_mp_enabled(): + return smp.rank() == 0 + else: + return self.args.process_index == 0 + + def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False): + """ + Will save the model, so you can reload it using `from_pretrained()`. + + Will only save from the main process. + """ + + if output_dir is None: + output_dir = self.args.output_dir + + if is_torch_xla_available(): + self._save_tpu(output_dir) + elif is_sagemaker_mp_enabled(): + # Calling the state_dict needs to be done on the wrapped model and on all processes. + os.makedirs(output_dir, exist_ok=True) + state_dict = self.model_wrapped.state_dict() + if self.args.should_save: + self._save(output_dir, state_dict=state_dict) + if IS_SAGEMAKER_MP_POST_1_10: + # 'user_content.pt' indicates model state_dict saved with smp >= 1.10 + Path(os.path.join(output_dir, "user_content.pt")).touch() + # We are in N-D parallelism if we have parallelism_config set, so we check accelerate if we're on a to_save rank + elif getattr(self.accelerator, "parallelism_config", None) is not None: + if self.accelerator.should_save_model: + self._save(output_dir) + # If we drop to here, we're in 1D parallelism, so all ranks need to go to `save_pretrained` + elif (tp_size := getattr(self.model, "_tp_size", 0)) is not None and tp_size > 1: + self._save(output_dir) + elif self.is_fsdp_enabled: + if "FULL_STATE_DICT" in str(self.accelerator.state.fsdp_plugin.state_dict_type): + state_dict = self.accelerator.get_state_dict(self.model) + if self.args.should_save: + self._save(output_dir, state_dict=state_dict) + elif self.is_deepspeed_enabled: + try: + state_dict = self.accelerator.get_state_dict(self.deepspeed) + if self.args.should_save: + self._save(output_dir, state_dict=state_dict) + except ValueError: + logger.warning( + " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use" + " zero_to_fp32.py to recover weights" + ) + if self.args.should_save: + self._save(output_dir, state_dict={}) + # remove the dummy state_dict + remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME]) + self.model_wrapped.save_checkpoint(output_dir) + + elif self.args.should_save: + self._save(output_dir) + + # Push to the Hub when `save_model` is called by the user. + if self.args.push_to_hub and not _internal_call: + self.push_to_hub(commit_message="Model save", revision=self.args.hub_revision) + + def _save_tpu(self, output_dir: Optional[str] = None): + output_dir = output_dir if output_dir is not None else self.args.output_dir + + logger.info(f"Saving model checkpoint to {output_dir}") + model = self.model + xm.mark_step() + + if xm.is_master_ordinal(local=False): + os.makedirs(output_dir, exist_ok=True) + torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) + + # Save a trained model and configuration using `save_pretrained()`. + # They can then be reloaded using `from_pretrained()` + supported_classes = (PushToHubMixin,) + xm.rendezvous("saving_checkpoint") + if self.is_fsdp_xla_v1_enabled: + ckpt = { + "model": model.state_dict(), + "shard_metadata": model.get_shard_metadata(), + } + ckpt_path = os.path.join( + output_dir, f"rank{self.args.process_index}-of-{self.args.world_size}-{WEIGHTS_NAME}" + ) + # All ranks save sharded checkpoint + xm.save(ckpt, ckpt_path, master_only=False) + # Make sure all ranks have saved checkpoints + xm.rendezvous("save_full_checkpoints") + # Master save full checkpoint + if self.args.should_save: + from torch_xla.distributed.fsdp import consolidate_sharded_model_checkpoints + + full_state_dict, _ = consolidate_sharded_model_checkpoints( + ckpt_prefix=os.path.join(output_dir, ""), + ckpt_suffix=f"rank*-of-*-{WEIGHTS_NAME}", + save_model=False, + ) + model = model.module.module + unwrapped_model = self.accelerator.unwrap_model(model) + if isinstance(unwrapped_model, supported_classes): + unwrapped_model.save_pretrained( + output_dir, + state_dict=full_state_dict, + save_function=xm.save, + safe_serialization=self.args.save_safetensors, + ) + else: + logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") + xm.save(full_state_dict, os.path.join(output_dir, WEIGHTS_NAME)) + elif not isinstance(model, supported_classes): + if isinstance(self.accelerator.unwrap_model(model), supported_classes): + self.accelerator.unwrap_model(model).save_pretrained( + output_dir, + is_main_process=self.args.should_save, + state_dict=xm._maybe_convert_to_cpu(model.state_dict()), + save_function=xm.save, + safe_serialization=self.args.save_safetensors, + ) + else: + logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") + state_dict = xm._maybe_convert_to_cpu(model.state_dict()) + xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) + else: + model.save_pretrained( + output_dir, + is_main_process=self.args.should_save, + save_function=xm.save, + safe_serialization=self.args.save_safetensors, + state_dict=xm._maybe_convert_to_cpu(model.state_dict()), + ) + if self.processing_class is not None and self.args.should_save: + self.processing_class.save_pretrained(output_dir) + + def _save(self, output_dir: Optional[str] = None, state_dict=None): + # If we are executing this function, we are the process zero, so we don't check for that. + output_dir = output_dir if output_dir is not None else self.args.output_dir + os.makedirs(output_dir, exist_ok=True) + logger.info(f"Saving model checkpoint to {output_dir}") + + supported_classes = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel) + # Save a trained model and configuration using `save_pretrained()`. + # They can then be reloaded using `from_pretrained()` + if not isinstance(self.model, supported_classes): + if state_dict is None: + state_dict = self.model.state_dict() + + if isinstance(self.accelerator.unwrap_model(self.model, keep_torch_compile=False), supported_classes): + self.accelerator.unwrap_model(self.model, keep_torch_compile=False).save_pretrained( + output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors + ) + else: + logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") + if self.args.save_safetensors: + safetensors.torch.save_file( + state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"} + ) + else: + torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) + else: + self.model.save_pretrained( + output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors + ) + + if self.processing_class is not None: + self.processing_class.save_pretrained(output_dir) + elif ( + self.data_collator is not None + and hasattr(self.data_collator, "tokenizer") + and self.data_collator.tokenizer is not None + ): + logger.info("Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`") + self.data_collator.tokenizer.save_pretrained(output_dir) + + # Good practice: save your training arguments together with the trained model + torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) + + def store_flos(self): + # Storing the number of floating-point operations that went into the model + if self.args.parallel_mode == ParallelMode.DISTRIBUTED: + self.state.total_flos += ( + distributed_broadcast_scalars([self.current_flos], device=self.args.device).sum().item() + ) + self.current_flos = 0 + else: + self.state.total_flos += self.current_flos + self.current_flos = 0 + + def _sorted_checkpoints( + self, output_dir=None, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False + ) -> list[str]: + ordering_and_checkpoint_path = [] + + glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x)] + + for path in glob_checkpoints: + if use_mtime: + ordering_and_checkpoint_path.append((os.path.getmtime(path), path)) + else: + regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path) + if regex_match is not None and regex_match.groups() is not None: + ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path)) + + checkpoints_sorted = sorted(ordering_and_checkpoint_path) + # mtime is not reliable on all filesystems, especially on some fuse fs in cloud environments + # so we check if the mtime is fake and fallback to numerical ordering if needed + if use_mtime and len(ordering_and_checkpoint_path) > 1: + mtime_diff = checkpoints_sorted[-1][0] - checkpoints_sorted[0][0] + if mtime_diff < 1.0: # less than 1 second, which is almost impossible when mtime works fine + warnings.warn("mtime may not be reliable on this filesystem, falling back to numerical ordering") + return self._sorted_checkpoints( + use_mtime=False, output_dir=output_dir, checkpoint_prefix=checkpoint_prefix + ) + checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted] + + # Make sure we don't delete the best model. + if ( + self.state.best_model_checkpoint is not None + and str(Path(self.state.best_model_checkpoint)) in checkpoints_sorted + ): + best_model_index = checkpoints_sorted.index(str(Path(self.state.best_model_checkpoint))) + for i in range(best_model_index, len(checkpoints_sorted) - 2): + checkpoints_sorted[i], checkpoints_sorted[i + 1] = checkpoints_sorted[i + 1], checkpoints_sorted[i] + return checkpoints_sorted + + def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None: + if self.args.save_total_limit is None or self.args.save_total_limit <= 0: + return + + # Check if we should delete older checkpoint(s) + checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime, output_dir=output_dir) + if len(checkpoints_sorted) <= self.args.save_total_limit: + return + + # If save_total_limit=1 with load_best_model_at_end=True, we could end up deleting the last checkpoint, which + # we don't do to allow resuming. + save_total_limit = self.args.save_total_limit + if ( + self.state.best_model_checkpoint is not None + and self.args.save_total_limit == 1 + and checkpoints_sorted[-1] != self.state.best_model_checkpoint + ): + save_total_limit = 2 + + number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit) + checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete] + for checkpoint in checkpoints_to_be_deleted: + logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit") + shutil.rmtree(checkpoint, ignore_errors=True) + + def evaluate( + self, + eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None, + ignore_keys: Optional[list[str]] = None, + metric_key_prefix: str = "eval", + ) -> dict[str, float]: + """ + Run evaluation and returns metrics. + + The calling script will be responsible for providing a method to compute metrics, as they are task-dependent + (pass it to the init `compute_metrics` argument). + + You can also subclass and override this method to inject custom behavior. + + Args: + eval_dataset (Union[`Dataset`, dict[str, `Dataset`]), *optional*): + Pass a dataset if you wish to override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns + not accepted by the `model.forward()` method are automatically removed. If it is a dictionary, it will + evaluate on each dataset, prepending the dictionary key to the metric name. Datasets must implement the + `__len__` method. + + + + If you pass a dictionary with names of datasets as keys and datasets as values, evaluate will run + separate evaluations on each dataset. This can be useful to monitor how training affects other + datasets or simply to get a more fine-grained evaluation. + When used with `load_best_model_at_end`, make sure `metric_for_best_model` references exactly one + of the datasets. If you, for example, pass in `{"data1": data1, "data2": data2}` for two datasets + `data1` and `data2`, you could specify `metric_for_best_model="eval_data1_loss"` for using the + loss on `data1` and `metric_for_best_model="eval_data2_loss"` for the loss on `data2`. + + + + ignore_keys (`list[str]`, *optional*): + A list of keys in the output of your model (if it is a dictionary) that should be ignored when + gathering predictions. + metric_key_prefix (`str`, *optional*, defaults to `"eval"`): + An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named + "eval_bleu" if the prefix is "eval" (default) + + Returns: + A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The + dictionary also contains the epoch number which comes from the training state. + """ + # handle multiple eval datasets + override = eval_dataset is not None + eval_dataset = eval_dataset if override else self.eval_dataset + if isinstance(eval_dataset, dict): + metrics = {} + for eval_dataset_name, _eval_dataset in eval_dataset.items(): + dataset_metrics = self.evaluate( + eval_dataset=_eval_dataset if override else eval_dataset_name, + ignore_keys=ignore_keys, + metric_key_prefix=f"{metric_key_prefix}_{eval_dataset_name}", + ) + metrics.update(dataset_metrics) + return metrics + + # memory metrics - must set up as early as possible + self._memory_tracker.start() + + eval_dataloader = self.get_eval_dataloader(eval_dataset) + if self.is_fsdp_xla_v2_enabled: + eval_dataloader = tpu_spmd_dataloader(eval_dataloader) + + start_time = time.time() + + eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop + output = eval_loop( + eval_dataloader, + description="Evaluation", + # No point gathering the predictions if there are no metrics, otherwise we defer to + # self.args.prediction_loss_only + prediction_loss_only=True if self.compute_metrics is None else None, + ignore_keys=ignore_keys, + metric_key_prefix=metric_key_prefix, + ) + + total_batch_size = self.args.eval_batch_size * self.args.world_size + if f"{metric_key_prefix}_jit_compilation_time" in output.metrics: + start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"] + if f"{metric_key_prefix}_model_preparation_time" in output.metrics: + start_time += output.metrics[f"{metric_key_prefix}_model_preparation_time"] + output.metrics.update( + speed_metrics( + metric_key_prefix, + start_time, + num_samples=output.num_samples, + num_steps=math.ceil(output.num_samples / total_batch_size), + ) + ) + + self.log(output.metrics) + + if DebugOption.TPU_METRICS_DEBUG in self.args.debug: + # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) + xm.master_print(met.metrics_report()) + + self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics) + + self._memory_tracker.stop_and_update_metrics(output.metrics) + + return output.metrics + + def predict( + self, test_dataset: Dataset, ignore_keys: Optional[list[str]] = None, metric_key_prefix: str = "test" + ) -> PredictionOutput: + """ + Run prediction and returns predictions and potential metrics. + + Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method + will also return metrics, like in `evaluate()`. + + Args: + test_dataset (`Dataset`): + Dataset to run the predictions on. If it is an `datasets.Dataset`, columns not accepted by the + `model.forward()` method are automatically removed. Has to implement the method `__len__` + ignore_keys (`list[str]`, *optional*): + A list of keys in the output of your model (if it is a dictionary) that should be ignored when + gathering predictions. + metric_key_prefix (`str`, *optional*, defaults to `"test"`): + An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named + "test_bleu" if the prefix is "test" (default) + + + + If your predictions or labels have different sequence length (for instance because you're doing dynamic padding + in a token classification task) the predictions will be padded (on the right) to allow for concatenation into + one array. The padding index is -100. + + + + Returns: *NamedTuple* A namedtuple with the following keys: + + - predictions (`np.ndarray`): The predictions on `test_dataset`. + - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some). + - metrics (`dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained + labels). + """ + # memory metrics - must set up as early as possible + self._memory_tracker.start() + + test_dataloader = self.get_test_dataloader(test_dataset) + start_time = time.time() + + eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop + output = eval_loop( + test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix + ) + total_batch_size = self.args.eval_batch_size * self.args.world_size + if f"{metric_key_prefix}_jit_compilation_time" in output.metrics: + start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"] + if f"{metric_key_prefix}_model_preparation_time" in output.metrics: + start_time += output.metrics[f"{metric_key_prefix}_model_preparation_time"] + output.metrics.update( + speed_metrics( + metric_key_prefix, + start_time, + num_samples=output.num_samples, + num_steps=math.ceil(output.num_samples / total_batch_size), + ) + ) + + self.control = self.callback_handler.on_predict(self.args, self.state, self.control, output.metrics) + self._memory_tracker.stop_and_update_metrics(output.metrics) + + return PredictionOutput(predictions=output.predictions, label_ids=output.label_ids, metrics=output.metrics) + + def evaluation_loop( + self, + dataloader: DataLoader, + description: str, + prediction_loss_only: Optional[bool] = None, + ignore_keys: Optional[list[str]] = None, + metric_key_prefix: str = "eval", + ) -> EvalLoopOutput: + """ + Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`. + + Works both with or without labels. + """ + args = self.args + + prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only + + # if eval is called w/o train, handle model prep here + if self.is_deepspeed_enabled and self.deepspeed is None: + _, _ = deepspeed_init(self, num_training_steps=0, inference=True) + + model = self._wrap_model(self.model, training=False, dataloader=dataloader) + + if len(self.accelerator._models) == 0 and model is self.model: + start_time = time.time() + model = ( + self.accelerator.prepare(model) + if self.is_deepspeed_enabled + or (self.is_fsdp_enabled and self.accelerator.mixed_precision != "fp8" and not self.args.torch_compile) + else self.accelerator.prepare_model(model, evaluation_mode=True) + ) + self.model_preparation_time = round(time.time() - start_time, 4) + + if self.is_fsdp_enabled: + self.model = model + + # for the rest of this function `model` is the outside model, whether it was wrapped or not + if model is not self.model: + self.model_wrapped = model + + # backward compatibility + if self.is_deepspeed_enabled: + self.deepspeed = self.model_wrapped + + # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called + # while ``train`` is running, cast it to the right dtype first and then put on device + if not self.is_in_train: + if args.fp16_full_eval: + model = model.to(dtype=torch.float16, device=args.device) + elif args.bf16_full_eval: + model = model.to(dtype=torch.bfloat16, device=args.device) + + batch_size = self.args.eval_batch_size + + logger.info(f"\n***** Running {description} *****") + if has_length(dataloader): + logger.info(f" Num examples = {self.num_examples(dataloader)}") + else: + logger.info(" Num examples: Unknown") + logger.info(f" Batch size = {batch_size}") + + if hasattr(model, "eval") and callable(model.eval): + model.eval() + if hasattr(self.optimizer, "eval") and callable(self.optimizer.eval): + self.optimizer.eval() + + self.callback_handler.eval_dataloader = dataloader + # Do this before wrapping. + eval_dataset = getattr(dataloader, "dataset", None) + + if args.past_index >= 0: + self._past = None + + # Initialize containers + all_losses = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100) + all_preds = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100) + all_labels = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100) + all_inputs = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100) + + metrics = None + eval_set_kwargs = {} + + # Will be useful when we have an iterable dataset so don't know its length. + observed_num_examples = 0 + + # Main evaluation loop + for step, inputs in enumerate(dataloader): + # Update the observed num examples + observed_batch_size = find_batch_size(inputs) + if observed_batch_size is not None: + observed_num_examples += observed_batch_size + # For batch samplers, batch_size is not known by the dataloader in advance. + if batch_size is None: + batch_size = observed_batch_size + + # Prediction step + losses, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys) + main_input_name = getattr(self.model, "main_input_name", "input_ids") + inputs_decode = ( + self._prepare_input(inputs[main_input_name]) if "inputs" in args.include_for_metrics else None + ) + + if is_torch_xla_available(): + xm.mark_step() + + # Update containers + if losses is not None: + losses = self.gather_function(losses.repeat(batch_size)) + all_losses.add(losses) + if inputs_decode is not None: + inputs_decode = self.accelerator.pad_across_processes(inputs_decode, dim=1, pad_index=-100) + inputs_decode = self.gather_function(inputs_decode) + if not self.args.batch_eval_metrics or description == "Prediction": + all_inputs.add(inputs_decode) + if labels is not None: + # Pad labels here, preparing for preprocess_logits_for_metrics in next logits block. + labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100) + if logits is not None: + logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100) + if self.preprocess_logits_for_metrics is not None: + logits = self.preprocess_logits_for_metrics(logits, labels) + logits = self.gather_function(logits) + if not self.args.batch_eval_metrics or description == "Prediction": + all_preds.add(logits) + if labels is not None: + labels = self.gather_function(labels) + if not self.args.batch_eval_metrics or description == "Prediction": + all_labels.add(labels) + + self.control = self.callback_handler.on_prediction_step(args, self.state, self.control) + + if self.args.batch_eval_metrics: + if self.compute_metrics is not None and logits is not None and labels is not None: + is_last_step = self.accelerator.gradient_state.end_of_dataloader + batch_kwargs = {} + batch_kwargs["losses"] = losses if "loss" in args.include_for_metrics else None + batch_kwargs["inputs"] = inputs if "inputs" in args.include_for_metrics else None + metrics = self.compute_metrics( + EvalPrediction(predictions=logits, label_ids=labels, **batch_kwargs), + compute_result=is_last_step, + ) + + del losses, logits, labels, inputs + torch.cuda.empty_cache() + + # Gather all tensors and put them back on the CPU if we have done enough accumulation steps. + elif args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0: + all_losses.to_cpu_and_numpy() + all_preds.to_cpu_and_numpy() + all_labels.to_cpu_and_numpy() + all_inputs.to_cpu_and_numpy() + + del losses, logits, labels, inputs + torch.cuda.empty_cache() + + # After all calls to `.gather_function`, reset to `gather_for_metrics`: + self.gather_function = self.accelerator.gather_for_metrics + if args.past_index and hasattr(self, "_past"): + # Clean the state at the end of the evaluation loop + delattr(self, "_past") + + # Gather all remaining tensors and put them back on the CPU + all_losses = all_losses.get_arrays() + all_preds = all_preds.get_arrays() + all_labels = all_labels.get_arrays() + all_inputs = all_inputs.get_arrays() + + # Number of samples + if has_length(eval_dataset): + num_samples = len(eval_dataset) + # The instance check is weird and does not actually check for the type, but whether the dataset has the right + # methods. Therefore we need to make sure it also has the attribute. + elif isinstance(eval_dataset, IterableDatasetShard) and getattr(eval_dataset, "num_examples", 0) > 0: + num_samples = eval_dataset.num_examples + else: + if has_length(dataloader): + num_samples = self.num_examples(dataloader) + else: # both len(dataloader.dataset) and len(dataloader) fail + num_samples = observed_num_examples + if num_samples == 0 and observed_num_examples > 0: + num_samples = observed_num_examples + + # Metrics! + if ( + self.compute_metrics is not None + and all_preds is not None + and all_labels is not None + and not self.args.batch_eval_metrics + ): + eval_set_kwargs["losses"] = all_losses if "loss" in args.include_for_metrics else None + eval_set_kwargs["inputs"] = all_inputs if "inputs" in args.include_for_metrics else None + metrics = self.compute_metrics( + EvalPrediction(predictions=all_preds, label_ids=all_labels, **eval_set_kwargs) + ) + elif metrics is None: + metrics = {} + + # To be JSON-serializable, we need to remove numpy types or zero-d tensors + metrics = denumpify_detensorize(metrics) + + if isinstance(all_losses, list) and all_losses: + metrics[f"{metric_key_prefix}_loss"] = np.concatenate(all_losses).mean().item() + elif isinstance(all_losses, np.ndarray): + metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item() + if hasattr(self, "jit_compilation_time"): + metrics[f"{metric_key_prefix}_jit_compilation_time"] = self.jit_compilation_time + if hasattr(self, "model_preparation_time"): + metrics[f"{metric_key_prefix}_model_preparation_time"] = self.model_preparation_time + + # Prefix all keys with metric_key_prefix + '_' + for key in list(metrics.keys()): + if not key.startswith(f"{metric_key_prefix}_"): + metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) + + return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples) + + def _nested_gather(self, tensors, name=None): + """ + Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before + concatenating them to `gathered` + """ + if tensors is None: + return + if is_torch_xla_available(): + if name is None: + name = "nested_gather" + tensors = nested_xla_mesh_reduce(tensors, name) + elif is_sagemaker_mp_enabled(): + tensors = smp_gather(tensors) + elif (self.args.distributed_state is not None and self.args.distributed_state.distributed_type != "NO") or ( + self.args.distributed_state is None and self.args.local_rank != -1 + ): + tensors = distributed_concat(tensors) + return tensors + + def prediction_step( + self, + model: nn.Module, + inputs: dict[str, Union[torch.Tensor, Any]], + prediction_loss_only: bool, + ignore_keys: Optional[list[str]] = None, + ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Perform an evaluation step on `model` using `inputs`. + + Subclass and override to inject custom behavior. + + Args: + model (`nn.Module`): + The model to evaluate. + inputs (`dict[str, Union[torch.Tensor, Any]]`): + The inputs and targets of the model. + + The dictionary will be unpacked before being fed to the model. Most models expect the targets under the + argument `labels`. Check your model's documentation for all accepted arguments. + prediction_loss_only (`bool`): + Whether or not to return the loss only. + ignore_keys (`list[str]`, *optional*): + A list of keys in the output of your model (if it is a dictionary) that should be ignored when + gathering predictions. + + Return: + tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, + logits and labels (each being optional). + """ + has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names) + # For CLIP-like models capable of returning loss values. + # If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss` + # is `True` in `model.forward`. + return_loss = inputs.get("return_loss") + if return_loss is None: + return_loss = self.can_return_loss + loss_without_labels = len(self.label_names) == 0 and return_loss + + inputs = self._prepare_inputs(inputs) + if ignore_keys is None: + if hasattr(self.model, "config"): + ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", ["past_key_values"]) + else: + ignore_keys = [] + + # labels may be popped when computing the loss (label smoothing for instance) so we grab them first. + if has_labels or loss_without_labels: + labels = nested_detach(tuple(inputs.get(name) for name in self.label_names)) + if len(labels) == 1: + labels = labels[0] + else: + labels = None + + with torch.no_grad(): + if is_sagemaker_mp_enabled(): + raw_outputs = smp_forward_only(model, inputs) + if has_labels or loss_without_labels: + if isinstance(raw_outputs, dict): + loss_mb = raw_outputs["loss"] + logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"]) + else: + loss_mb = raw_outputs[0] + logits_mb = raw_outputs[1:] + + loss = loss_mb.reduce_mean().detach().cpu() + logits = smp_nested_concat(logits_mb) + else: + loss = None + if isinstance(raw_outputs, dict): + logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys) + else: + logits_mb = raw_outputs + logits = smp_nested_concat(logits_mb) + else: + if has_labels or loss_without_labels: + with self.compute_loss_context_manager(): + num_items_in_batch = self._get_num_items_in_batch([inputs], self.args.device) + loss, outputs = self.compute_loss( + model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch + ) + loss = loss.detach().mean() + + if isinstance(outputs, dict): + logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"]) + else: + logits = outputs[1:] + else: + loss = None + with self.compute_loss_context_manager(): + outputs = model(**inputs) + if isinstance(outputs, dict): + logits = tuple(v for k, v in outputs.items() if k not in ignore_keys) + else: + logits = outputs + # TODO: this needs to be fixed and made cleaner later. + if self.args.past_index >= 0: + self._past = outputs[self.args.past_index - 1] + + if prediction_loss_only: + return (loss, None, None) + + logits = nested_detach(logits) + if len(logits) == 1: + logits = logits[0] + + return (loss, logits, labels) + + def floating_point_ops(self, inputs: dict[str, Union[torch.Tensor, Any]]): + """ + For models that inherit from [`PreTrainedModel`], uses that method to compute the number of floating point + operations for every backward + forward pass. If using another model, either implement such a method in the + model or subclass and override this method. + + Args: + inputs (`dict[str, Union[torch.Tensor, Any]]`): + The inputs and targets of the model. + + Returns: + `int`: The number of floating-point operations. + """ + if hasattr(self.model, "floating_point_ops"): + return self.model.floating_point_ops(inputs) + else: + return 0 + + def init_hf_repo(self, token: Optional[str] = None): + """ + Initializes a git repo in `self.args.hub_model_id`. + """ + # Only on process zero + if not self.is_world_process_zero(): + return + + if self.args.hub_model_id is None: + repo_name = Path(self.args.output_dir).absolute().name + else: + repo_name = self.args.hub_model_id + + token = token if token is not None else self.args.hub_token + repo_url = create_repo(repo_name, token=token, private=self.args.hub_private_repo, exist_ok=True) + self.hub_model_id = repo_url.repo_id + self.push_in_progress = None + + def create_model_card( + self, + language: Optional[str] = None, + license: Optional[str] = None, + tags: Union[str, list[str], None] = None, + model_name: Optional[str] = None, + finetuned_from: Optional[str] = None, + tasks: Union[str, list[str], None] = None, + dataset_tags: Union[str, list[str], None] = None, + dataset: Union[str, list[str], None] = None, + dataset_args: Union[str, list[str], None] = None, + ): + """ + Creates a draft of a model card using the information available to the `Trainer`. + + Args: + language (`str`, *optional*): + The language of the model (if applicable) + license (`str`, *optional*): + The license of the model. Will default to the license of the pretrained model used, if the original + model given to the `Trainer` comes from a repo on the Hub. + tags (`str` or `list[str]`, *optional*): + Some tags to be included in the metadata of the model card. + model_name (`str`, *optional*): + The name of the model. + finetuned_from (`str`, *optional*): + The name of the model used to fine-tune this one (if applicable). Will default to the name of the repo + of the original model given to the `Trainer` (if it comes from the Hub). + tasks (`str` or `list[str]`, *optional*): + One or several task identifiers, to be included in the metadata of the model card. + dataset_tags (`str` or `list[str]`, *optional*): + One or several dataset tags, to be included in the metadata of the model card. + dataset (`str` or `list[str]`, *optional*): + One or several dataset identifiers, to be included in the metadata of the model card. + dataset_args (`str` or `list[str]`, *optional*): + One or several dataset arguments, to be included in the metadata of the model card. + """ + if not self.is_world_process_zero(): + return + + model_card_filepath = os.path.join(self.args.output_dir, "README.md") + is_peft_library = False + if os.path.exists(model_card_filepath): + library_name = ModelCard.load(model_card_filepath).data.get("library_name") + is_peft_library = library_name == "peft" + + # Append existing tags in `tags` + existing_tags = ModelCard.load(model_card_filepath).data.tags + if tags is not None and existing_tags is not None: + if isinstance(tags, str): + tags = [tags] + for tag in existing_tags: + if tag not in tags: + tags.append(tag) + + training_summary = TrainingSummary.from_trainer( + self, + language=language, + license=license, + tags=tags, + model_name=model_name, + finetuned_from=finetuned_from, + tasks=tasks, + dataset_tags=dataset_tags, + dataset=dataset, + dataset_args=dataset_args, + ) + model_card = training_summary.to_model_card() + with open(model_card_filepath, "w") as f: + f.write(model_card) + + if is_peft_library: + self.accelerator.unwrap_model(self.model).create_or_update_model_card(self.args.output_dir) + + def _push_from_checkpoint(self, checkpoint_folder): + # Only push from one node. + if not self.is_world_process_zero() or self.args.hub_strategy == HubStrategy.END: + return + # If we haven't finished the last push, we don't do this one unless args.hub_always_push=True. + if not self.args.hub_always_push and self.push_in_progress is not None and not self.push_in_progress.is_done(): + return + + output_dir = self.args.output_dir + # To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder + modeling_files = [CONFIG_NAME, GENERATION_CONFIG_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_NAME] + # Add sharded checkpoints if we have an index + for index_file in [WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME]: + index_path = os.path.join(checkpoint_folder, index_file) + if os.path.isfile(index_path): + modeling_files.append(index_file) + with open(index_path) as f: + index = json.loads(f.read()) + shard_files = list(set(index["weight_map"].values())) + modeling_files.extend(shard_files) + if is_peft_available(): + modeling_files.extend([ADAPTER_CONFIG_NAME, ADAPTER_WEIGHTS_NAME, ADAPTER_SAFE_WEIGHTS_NAME]) + for modeling_file in modeling_files: + if os.path.isfile(os.path.join(checkpoint_folder, modeling_file)): + shutil.copy(os.path.join(checkpoint_folder, modeling_file), os.path.join(output_dir, modeling_file)) + # Saving the processing class is fast and we don't know how many files it may have spawned, so we resave it to be sure. + if self.processing_class is not None: + self.processing_class.save_pretrained(output_dir) + # Same for the training arguments + torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) + + if self.args.save_strategy == SaveStrategy.STEPS: + commit_message = f"Training in progress, step {self.state.global_step}" + else: + commit_message = f"Training in progress, epoch {int(self.state.epoch)}" + + model_push_job = upload_folder( + repo_id=self.hub_model_id, + folder_path=output_dir, + commit_message=commit_message, + token=self.args.hub_token, + run_as_future=True, + ignore_patterns=["_*", f"{PREFIX_CHECKPOINT_DIR}-*"], + revision=self.args.hub_revision, + ) + + push_jobs = [model_push_job] + + if self.args.hub_strategy in [HubStrategy.CHECKPOINT, HubStrategy.ALL_CHECKPOINTS]: + path_in_repo = ( + "last-checkpoint" if self.args.hub_strategy == HubStrategy.CHECKPOINT else Path(checkpoint_folder).name + ) + checkpoint_push = upload_folder( + repo_id=self.hub_model_id, + folder_path=checkpoint_folder, + path_in_repo=path_in_repo, + commit_message=commit_message + ", checkpoint", + token=self.args.hub_token, + run_as_future=True, + revision=self.args.hub_revision, + ) + push_jobs.append(checkpoint_push) + + if self.push_in_progress is None or self.push_in_progress.is_done(): + self.push_in_progress = PushInProgress(push_jobs) + else: + self.push_in_progress.jobs.extend(push_jobs) + + def _finish_current_push(self): + if not hasattr(self, "push_in_progress"): + return + if self.push_in_progress is not None and not self.push_in_progress.is_done(): + logger.info("Waiting for the current checkpoint push to be finished, this might take a couple of minutes.") + self.push_in_progress.wait_until_done() + + def push_to_hub( + self, + commit_message: Optional[str] = "End of training", + blocking: bool = True, + token: Optional[str] = None, + revision: Optional[str] = None, + **kwargs, + ) -> str: + """ + Upload `self.model` and `self.processing_class` to the 🤗 model hub on the repo `self.args.hub_model_id`. + + Parameters: + commit_message (`str`, *optional*, defaults to `"End of training"`): + Message to commit while pushing. + blocking (`bool`, *optional*, defaults to `True`): + Whether the function should return only when the `git push` has finished. + token (`str`, *optional*, defaults to `None`): + Token with write permission to overwrite Trainer's original args. + revision (`str`, *optional*): + The git revision to commit from. Defaults to the head of the "main" branch. + kwargs (`dict[str, Any]`, *optional*): + Additional keyword arguments passed along to [`~Trainer.create_model_card`]. + + Returns: + The URL of the repository where the model was pushed if `blocking=False`, or a `Future` object tracking the + progress of the commit if `blocking=True`. + """ + model_name = kwargs.pop("model_name", None) + if model_name is None and self.args.should_save: + if self.args.hub_model_id is None: + model_name = Path(self.args.output_dir).name + else: + model_name = self.args.hub_model_id.split("/")[-1] + token = token if token is not None else self.args.hub_token + + # In case the user calls this method with args.push_to_hub = False + if self.hub_model_id is None: + self.init_hf_repo(token=token) + + # Needs to be executed on all processes for TPU training, but will only save on the processed determined by + # self.args.should_save. + self.save_model(_internal_call=True) + + # Only push from one node. + if not self.is_world_process_zero(): + return + + # Add additional tags in the case the model has already some tags and users pass + # "tags" argument to `push_to_hub` so that trainer automatically handles internal tags + # from all models since Trainer does not call `model.push_to_hub`. + if getattr(self.model, "model_tags", None) is not None: + if "tags" not in kwargs: + kwargs["tags"] = [] + + # If it is a string, convert it to a list + if isinstance(kwargs["tags"], str): + kwargs["tags"] = [kwargs["tags"]] + + for model_tag in self.model.model_tags: + if model_tag not in kwargs["tags"]: + kwargs["tags"].append(model_tag) + + self.create_model_card(model_name=model_name, **kwargs) + + if revision is None: + revision = self.args.hub_revision + + # Wait for the current upload to be finished. + self._finish_current_push() + + return upload_folder( + repo_id=self.hub_model_id, + folder_path=self.args.output_dir, + commit_message=commit_message, + token=token, + run_as_future=not blocking, + ignore_patterns=["_*", f"{PREFIX_CHECKPOINT_DIR}-*"], + revision=revision, + ) + + # + # Deprecated code + # + + def prediction_loop( + self, + dataloader: DataLoader, + description: str, + prediction_loss_only: Optional[bool] = None, + ignore_keys: Optional[list[str]] = None, + metric_key_prefix: str = "eval", + ) -> EvalLoopOutput: + """ + Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`. + + Works both with or without labels. + """ + args = self.args + + if not has_length(dataloader): + raise ValueError("dataloader must implement a working __len__") + + prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only + + # if eval is called w/o train, handle model prep here + if self.is_deepspeed_enabled and self.deepspeed is None: + _, _ = deepspeed_init(self, num_training_steps=0, inference=True) + + model = self._wrap_model(self.model, training=False, dataloader=dataloader) + + if len(self.accelerator._models) == 0 and model is self.model: + model = ( + self.accelerator.prepare(model) + if self.is_deepspeed_enabled or self.is_fsdp_enabled + else self.accelerator.prepare_model(model, evaluation_mode=True) + ) + + if self.is_fsdp_enabled: + self.model = model + + # for the rest of this function `model` is the outside model, whether it was wrapped or not + if model is not self.model: + self.model_wrapped = model + + # backward compatibility + if self.is_deepspeed_enabled: + self.deepspeed = self.model_wrapped + + # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called + # while ``train`` is running, cast it to the right dtype first and then put on device + if not self.is_in_train: + if args.fp16_full_eval: + model = model.to(dtype=torch.float16, device=args.device) + elif args.bf16_full_eval: + model = model.to(dtype=torch.bfloat16, device=args.device) + + batch_size = ( + dataloader.total_batch_size + if getattr(dataloader, "_is_accelerate_prepared", False) + else dataloader.batch_size + ) + + if batch_size is None: + raise ValueError( + "Batch size cannot be None. Ensure the dataloader has a valid batch_size or total_batch_size." + ) + + num_examples = self.num_examples(dataloader) + logger.info(f"\n***** Running {description} *****") + logger.info(f" Num examples = {num_examples}") + logger.info(f" Batch size = {batch_size}") + + losses_host: Optional[torch.Tensor] = None + preds_host: Union[torch.Tensor, list[torch.Tensor], None] = None + labels_host: Union[torch.Tensor, list[torch.Tensor], None] = None + inputs_host: Union[torch.Tensor, list[torch.Tensor], None] = None + metrics: Optional[dict] = None + eval_set_kwargs: dict = {} + + world_size = max(1, args.world_size) + + eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size) + if not prediction_loss_only: + # The actual number of eval_sample can be greater than num_examples in distributed settings (when we pass + # a batch size to the sampler) + make_multiple_of = None + if hasattr(dataloader, "sampler") and isinstance(dataloader.sampler, SequentialDistributedSampler): + make_multiple_of = dataloader.sampler.batch_size + preds_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of) + labels_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of) + inputs_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of) + + model.eval() + if hasattr(self.optimizer, "eval") and callable(self.optimizer.eval): + self.optimizer.eval() + + if args.past_index >= 0: + self._past = None + + self.callback_handler.eval_dataloader = dataloader + + for step, inputs in enumerate(dataloader): + loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys) + main_input_name = getattr(self.model, "main_input_name", "input_ids") + inputs_decode = ( + self._prepare_input(inputs[main_input_name]) if "inputs" in args.include_for_metrics else None + ) + + if loss is not None: + losses = loss.repeat(batch_size) + losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0) + if logits is not None: + preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100) + if labels is not None: + labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100) + if inputs_decode is not None: + inputs_host = ( + inputs_decode + if inputs_host is None + else nested_concat(inputs_host, inputs_decode, padding_index=-100) + ) + self.control = self.callback_handler.on_prediction_step(args, self.state, self.control) + + if self.args.batch_eval_metrics: + if self.compute_metrics is not None and preds_host is not None and labels_host is not None: + is_last_step = self.accelerator.gradient_state.end_of_dataloader + batch_kwargs = {} + batch_kwargs["losses"] = losses_host if "loss" in args.include_for_metrics else None + batch_kwargs["inputs"] = inputs_host if "inputs" in args.include_for_metrics else None + metrics = self.compute_metrics( + EvalPrediction(predictions=preds_host, label_ids=labels_host, **batch_kwargs), + compute_result=is_last_step, + ) + + if self.args.batch_eval_metrics or ( + args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0 + ): + # Gather all tensors and put them back on the CPU if we have done enough accumulation steps. + eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses")) + if not prediction_loss_only: + preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds")) + labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids")) + inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids")) + + # Set back to None to begin a new accumulation + del losses_host, preds_host, labels_host, inputs_host + torch.cuda.empty_cache() + losses_host, preds_host, labels_host, inputs_host = None, None, None, None + + if args.past_index and hasattr(self, "_past"): + # Clean the state at the end of the evaluation loop + delattr(self, "_past") + + # Gather all remaining tensors and put them back on the CPU + eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses")) + if not prediction_loss_only: + preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds")) + labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids")) + inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids")) + + eval_loss = eval_losses_gatherer.finalize() + preds = preds_gatherer.finalize() if not prediction_loss_only else None + label_ids = labels_gatherer.finalize() if not prediction_loss_only else None + inputs_ids = inputs_gatherer.finalize() if not prediction_loss_only else None + + if ( + self.compute_metrics is not None + and preds is not None + and label_ids is not None + and not self.args.batch_eval_metrics + ): + eval_set_kwargs["losses"] = eval_loss if "loss" in args.include_for_metrics else None + eval_set_kwargs["inputs"] = inputs_ids if "inputs" in args.include_for_metrics else None + metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids, **eval_set_kwargs)) + elif metrics is None: + metrics = {} + + # To be JSON-serializable, we need to remove numpy types or zero-d tensors + metrics = denumpify_detensorize(metrics) + + if eval_loss is not None: + metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item() + + # Prefix all keys with metric_key_prefix + '_' + for key in list(metrics.keys()): + if not key.startswith(f"{metric_key_prefix}_"): + metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) + + return EvalLoopOutput(predictions=preds, label_ids=label_ids, metrics=metrics, num_samples=num_examples) + + def _gather_and_numpify(self, tensors, name): + """ + Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before + concatenating them to `gathered` + """ + if tensors is None: + return + if is_torch_xla_available(): + tensors = nested_xla_mesh_reduce(tensors, name) + elif is_sagemaker_mp_enabled(): + tensors = smp_gather(tensors) + elif self.args.parallel_mode == ParallelMode.DISTRIBUTED: + tensors = distributed_concat(tensors) + + return nested_numpify(tensors) + + def _add_sm_patterns_to_gitignore(self) -> None: + """Add SageMaker Checkpointing patterns to .gitignore file.""" + # Make sure we only do this on the main process + if not self.is_world_process_zero(): + return + + patterns = ["*.sagemaker-uploading", "*.sagemaker-uploaded"] + + # Get current .gitignore content + if os.path.exists(os.path.join(self.repo.local_dir, ".gitignore")): + with open(os.path.join(self.repo.local_dir, ".gitignore")) as f: + current_content = f.read() + else: + current_content = "" + + # Add the patterns to .gitignore + content = current_content + for pattern in patterns: + if pattern not in content: + if content.endswith("\n"): + content += pattern + else: + content += f"\n{pattern}" + + # Write the .gitignore file if it has changed + if content != current_content: + with open(os.path.join(self.repo.local_dir, ".gitignore"), "w") as f: + logger.debug(f"Writing .gitignore file. Content: {content}") + f.write(content) + + self.repo.git_add(".gitignore") + + # avoid race condition with git status + time.sleep(0.5) + + if not self.repo.is_repo_clean(): + self.repo.git_commit("Add *.sagemaker patterns to .gitignore.") + self.repo.git_push() + + def create_accelerator_and_postprocess(self): + # We explicitly don't rely on the `Accelerator` to do gradient accumulation + grad_acc_kwargs = {} + if is_accelerate_available("0.28.0") and self.args.accelerator_config.gradient_accumulation_kwargs is not None: + grad_acc_kwargs = self.args.accelerator_config.gradient_accumulation_kwargs + + # check if num_steps is attempted to be passed in gradient_accumulation_kwargs + if "num_steps" in grad_acc_kwargs: + if self.args.gradient_accumulation_steps > 1: + # raise because we do not know which setting is intended. + raise ValueError( + "The `AcceleratorConfig`'s `num_steps` is set but `gradient_accumulation_steps` is greater than 1 in the passed `TrainingArguments`" + "If using the passed `AcceleratorConfig` is desired, do not set the `TrainingArguments` `gradient_accumulation_steps`." + ) + else: + self.args.gradient_accumulation_steps = grad_acc_kwargs["num_steps"] + + accelerator_config = self.args.accelerator_config.to_dict() + + if is_accelerate_available("0.28.0"): + # Extract dataloader config params from accelerator config + dataloader_params = ["split_batches", "dispatch_batches", "even_batches", "use_seedable_sampler"] + dataloader_config = DataLoaderConfiguration( + **{param: accelerator_config.pop(param) for param in dataloader_params} + ) + if is_accelerate_available("1.1.0"): + dataloader_config.data_seed = self.args.data_seed + + non_blocking = accelerator_config.pop("non_blocking") + if not is_accelerate_available("0.30.0"): + if non_blocking: + raise ImportError( + "`non_blocking` is only supported in accelerate v0.30.0 and above. Please upgrade accelerate to use this feature." + ) + else: + if non_blocking and not self.args.dataloader_pin_memory: + logger.warning( + "`non_blocking` is enabled but `dataloader_pin_memory` is not. For the best performance, it's recommended to enable both." + ) + dataloader_config.non_blocking = non_blocking + # this would have been updated above, no need for it anymore + accelerator_config.pop("gradient_accumulation_kwargs") + + args = { + "deepspeed_plugin": self.args.deepspeed_plugin, + } + + # We defer compatibility checks to accelerator + if self.args.parallelism_config is not None: + if not is_accelerate_available("1.10.1"): + raise ImportError( + "ParallelismConfig requires accelerate v1.10.1 and above. Please upgrade accelerate to use this feature." + ) + + args["parallelism_config"] = self.args.parallelism_config + + if is_accelerate_available("0.28.0"): + args["dataloader_config"] = dataloader_config + else: + args.update(accelerator_config) + # tp is initialized at Accelerator init phase so + # args should be prepared here + if hasattr(self.model, "tp_size") and self.model.tp_size is not None and self.model.tp_size > 1: + self.is_tp_enabled = True + if version.parse(accelerate_version) > version.parse("1.3.0"): + args["torch_tp_plugin"] = TorchTensorParallelPlugin(tp_size=self.model.tp_size) + else: + raise ValueError("Requires accelerate>1.3.0 to use Tensor Parallelism.") + + # create accelerator object + self.accelerator = Accelerator(**args) + # some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag + self.gather_function = self.accelerator.gather_for_metrics + + if "use_gather_object" in inspect.signature(self.gather_function).parameters: + self.gather_function = functools.partial( + self.gather_function, use_gather_object=self.args.eval_use_gather_object + ) + + # deepspeed and accelerate flags covering both trainer args and accelerate launcher + self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None + self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None + self.is_tp_enabled = getattr(self.accelerator.state, "torch_tp_plugin", None) is not None + # post accelerator creation setup + if self.is_fsdp_enabled: + fsdp_plugin = self.accelerator.state.fsdp_plugin + for param in ["limit_all_gathers", "activation_checkpointing"]: + setattr(fsdp_plugin, param, self.args.fsdp_config.get(param, getattr(fsdp_plugin, param))) + if fsdp_plugin.activation_checkpointing and self.args.gradient_checkpointing: + raise ValueError( + "The activation_checkpointing in FSDP config and the gradient_checkpointing in training arg " + "can't be set to True simultaneously. Please use FSDP's activation_checkpointing logic " + "when using FSDP." + ) + + if self.is_deepspeed_enabled and getattr(self.args, "hf_deepspeed_config", None) is None: + self.propagate_args_to_deepspeed() + + # `save_only_model` can't be used with DeepSpeed/FSDP along with `load_best_model_at_end` + if ( + self.args.save_only_model + and (self.is_deepspeed_enabled or self.is_fsdp_enabled) + and self.args.load_best_model_at_end + ): + wrapper = "DeepSpeed" if self.is_deepspeed_enabled else "FSDP" + raise ValueError(f"{wrapper} can't be used with `save_only_model` along with `load_best_model_at_end`.") + + # `auto_find_batch_size` isn't supported yet with DeepSpeed Zero-3 + if ( + self.is_deepspeed_enabled + and self.accelerator.state.deepspeed_plugin.zero_stage == 3 + and self.args.auto_find_batch_size + ): + raise ValueError( + "`auto_find_batch_size` isn't supported yet with DeepSpeed Zero-3. Please consider using Zero-2, Zero-1, or FSDP" + ) + if ( + self.args.save_only_model + and self.is_fsdp_enabled + and "SHARDED_STATE_DICT" in str(self.accelerator.state.fsdp_plugin.state_dict_type) + ): + raise ValueError("save_only_model option is not compatible with FSDP state dict type 'SHARDED_STATE_DICT'") + + def propagate_args_to_deepspeed(self, auto_find_batch_size=False): + """ + Sets values in the deepspeed plugin based on the Trainer args + """ + from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig + + ds_plugin = self.accelerator.state.deepspeed_plugin + + ds_plugin.hf_ds_config = HfTrainerDeepSpeedConfig(ds_plugin.hf_ds_config.config) + ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config + ds_plugin.hf_ds_config.trainer_config_process(self.args, auto_find_batch_size) + + def _fsdp_qlora_plugin_updates(self): + if self.is_fsdp_enabled and _is_peft_model(self.model): + from peft import PeftConfig + from peft.utils.other import fsdp_auto_wrap_policy + + if isinstance(self.model.active_peft_config, PeftConfig): + self.accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(self.model) + if ( + getattr(self.model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES + and self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage.is_floating_point + and version.parse(accelerate_version) > version.parse("0.27.0") + ): + self.accelerator.state.fsdp_plugin.set_mixed_precision( + self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage, override=True + ) + + def _get_num_items_in_batch(self, batch_samples: list, device: torch.device) -> Optional[Union[torch.Tensor, int]]: + """ + Counts the number of items in the batches to properly scale the loss. + Args: + batch_samples (`list`): List of batches + device (`torch.device`): The device on which the number of items in the batch should be. + Returns: + None if the number of items in the batch doesn't need to be computed else the number of items in the batch + """ + num_items_in_batch = None + count_num_items_in_batch = ( + len(batch_samples) > 0 + and "labels" in batch_samples[0] + and ( + # num_items_in_batch is passed to model forward + # https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/trainer.py#L3757 + self.model_accepts_loss_kwargs + # num_items_in_batch is passed to compute_loss_func + # https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/trainer.py#L3773 + or self.compute_loss_func is not None + # num_items_in_batch is also verified if (self.model_accepts_loss_kwargs or self.compute_loss_func) + # https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/trainer.py#L3790 + ) + ) + if count_num_items_in_batch: + # For now we don't support object detection + try: + num_items_in_batch = sum((batch["labels"].ne(-100)).sum() for batch in batch_samples) + except (TypeError, AttributeError): + pass + + if num_items_in_batch is not None: + if self.args.average_tokens_across_devices and self.args.world_size >= 1: + num_items_in_batch = self.accelerator.gather(num_items_in_batch.to(device)).sum() + elif self.args.n_gpu >= 1: + # In DP case, if we don't average, we need to divide by the number of gpu. This is the simplest approximation. + # Otherwise, we would have to scatter labels and calculate num_items_in_batch for each gpu. + num_items_in_batch = num_items_in_batch // self.args.n_gpu + + if torch.is_tensor(num_items_in_batch): + num_items_in_batch = num_items_in_batch.to(device) + + if self.args.n_gpu > 1 and num_items_in_batch.dim() == 0: + # In the DataParallel case, convert the scalar tensor into a 2-dim tensor with the same value repeated + num_items_in_batch = num_items_in_batch.unsqueeze(0).expand(self.args.n_gpu, -1) + # Divide by number of devices with the same batch + if pc := getattr(self.accelerator, "parallelism_config", None): + num_items_in_batch = num_items_in_batch // pc.non_data_parallel_size + + return num_items_in_batch + + def get_batch_samples( + self, epoch_iterator: Iterator, num_batches: int, device: torch.device + ) -> tuple[list, Optional[Union[torch.Tensor, int]]]: + """ + Collects a specified number of batches from the epoch iterator and optionally counts the number of items in the batches to properly scale the loss. + """ + batch_samples = [] + + for _ in range(num_batches): + try: + batch_samples.append(next(epoch_iterator)) + except StopIteration: + break + + num_items_in_batch = self._get_num_items_in_batch(batch_samples, device) + return batch_samples, num_items_in_batch + + def set_initial_training_values( + self, args: TrainingArguments, dataloader: DataLoader, total_train_batch_size: int + ): + """ + Calculates and returns the following values: + - `num_train_epochs` + - `num_update_steps_per_epoch` + - `num_examples` + - `num_train_samples` + - `epoch_based` + - `len_dataloader` + - `max_steps` + """ + # Case 1: we rely on `args.max_steps` first + max_steps = args.max_steps + # If max_steps is negative, we use the number of epochs to determine the number of total steps later + epoch_based = max_steps < 0 + len_dataloader = len(dataloader) if has_length(dataloader) else None + + # Case 2: We have a dataloader length and can extrapolate + if len_dataloader is not None: + num_update_steps_per_epoch = max( + len_dataloader // args.gradient_accumulation_steps + + int(len_dataloader % args.gradient_accumulation_steps > 0), + 1, + ) + # Case 3: We have a length but are using epochs, we can extrapolate the number of steps + if epoch_based: + max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch) + + # Now we figure out `num_examples`, `num_train_epochs`, and `train_samples` + if len_dataloader: + num_examples = self.num_examples(dataloader) + if args.max_steps > 0: + num_train_epochs = max_steps // num_update_steps_per_epoch + int( + max_steps % num_update_steps_per_epoch > 0 + ) + # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's + # the best we can do. + num_train_samples = max_steps * total_train_batch_size + else: + num_train_epochs = math.ceil(args.num_train_epochs) + num_train_samples = self.num_examples(dataloader) * args.num_train_epochs + elif args.max_steps > 0: # Rely on max_steps when dataloader does not have a working size + # Setting a very large number of epochs so we go as many times as necessary over the iterator. + num_train_epochs = sys.maxsize + num_update_steps_per_epoch = max_steps + num_examples = total_train_batch_size * args.max_steps + num_train_samples = args.max_steps * total_train_batch_size + else: + raise ValueError( + "args.max_steps must be set to a positive value if dataloader does not have a length, was" + f" {args.max_steps}" + ) + return ( + num_train_epochs, + num_update_steps_per_epoch, + num_examples, + num_train_samples, + epoch_based, + len_dataloader, + max_steps, + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/trainer_callback.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/trainer_callback.py new file mode 100644 index 0000000000000000000000000000000000000000..c72bdbb70bcd189582b8b69cd121be6a10c5c4e8 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/trainer_callback.py @@ -0,0 +1,785 @@ +# Copyright 2020-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Callbacks to use with the Trainer class and customize the training loop. +""" + +import dataclasses +import json +import math +from dataclasses import dataclass +from typing import Optional, Union + +import numpy as np +from tqdm.auto import tqdm + +from .trainer_utils import HPSearchBackend, IntervalStrategy, SaveStrategy, has_length +from .training_args import TrainingArguments +from .utils import logging + + +logger = logging.get_logger(__name__) + + +@dataclass +class TrainerState: + """ + A class containing the [`Trainer`] inner state that will be saved along the model and optimizer when checkpointing + and passed to the [`TrainerCallback`]. + + + + In all this class, one step is to be understood as one update step. When using gradient accumulation, one update + step may require several forward and backward passes: if you use `gradient_accumulation_steps=n`, then one update + step requires going through *n* batches. + + + + Args: + epoch (`float`, *optional*): + Only set during training, will represent the epoch the training is at (the decimal part being the + percentage of the current epoch completed). + global_step (`int`, *optional*, defaults to 0): + During training, represents the number of update steps completed. + max_steps (`int`, *optional*, defaults to 0): + The number of update steps to do during the current training. + logging_steps (`int`, *optional*, defaults to 500): + Log every X updates steps + eval_steps (`int`, *optional*): + Run an evaluation every X steps. + save_steps (`int`, *optional*, defaults to 500): + Save checkpoint every X updates steps. + train_batch_size (`int`, *optional*): + The batch size for the training dataloader. Only needed when + `auto_find_batch_size` has been used. + num_input_tokens_seen (`int`, *optional*, defaults to 0): + When tracking the inputs tokens, the number of tokens seen during training (number of input tokens, not the + number of prediction tokens). + total_flos (`float`, *optional*, defaults to 0): + The total number of floating operations done by the model since the beginning of training (stored as floats + to avoid overflow). + log_history (`list[dict[str, float]]`, *optional*): + The list of logs done since the beginning of training. + best_metric (`float`, *optional*): + When tracking the best model, the value of the best metric encountered so far. + best_global_step (`int`, *optional*): + When tracking the best model, the step at which the best metric was encountered. + Used for setting `best_model_checkpoint`. + best_model_checkpoint (`str`, *optional*): + When tracking the best model, the value of the name of the checkpoint for the best model encountered so + far. + is_local_process_zero (`bool`, *optional*, defaults to `True`): + Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on + several machines) main process. + is_world_process_zero (`bool`, *optional*, defaults to `True`): + Whether or not this process is the global main process (when training in a distributed fashion on several + machines, this is only going to be `True` for one process). + is_hyper_param_search (`bool`, *optional*, defaults to `False`): + Whether we are in the process of a hyper parameter search using Trainer.hyperparameter_search. This will + impact the way data will be logged in TensorBoard. + stateful_callbacks (`list[StatefulTrainerCallback]`, *optional*): + Callbacks attached to the `Trainer` that should have their states be saved or restored. + Relevant callbacks should implement a `state` and `from_state` function. + """ + + epoch: Optional[float] = None + global_step: int = 0 + max_steps: int = 0 + logging_steps: int = 500 + eval_steps: int = 500 + save_steps: int = 500 + train_batch_size: Optional[int] = None + num_train_epochs: int = 0 + num_input_tokens_seen: int = 0 + total_flos: float = 0 + log_history: list[dict[str, float]] = None + best_metric: Optional[float] = None + best_global_step: Optional[int] = None + best_model_checkpoint: Optional[str] = None + is_local_process_zero: bool = True + is_world_process_zero: bool = True + is_hyper_param_search: bool = False + trial_name: Optional[str] = None + trial_params: Optional[dict[str, Union[str, float, int, bool]]] = None + stateful_callbacks: Optional[list["TrainerCallback"]] = None + + def __post_init__(self): + if self.log_history is None: + self.log_history = [] + if self.stateful_callbacks is None: + self.stateful_callbacks = {} + elif isinstance(self.stateful_callbacks, dict): + # We are loading the callbacks in from the state file, no need to process them + pass + else: + # Saveable callbacks get stored as dict of kwargs + stateful_callbacks = {} + for callback in self.stateful_callbacks: + if not isinstance(callback, (ExportableState)): + raise TypeError( + f"All callbacks passed to be saved must inherit `ExportableState`, but received {type(callback)}" + ) + name = callback.__class__.__name__ + if name in stateful_callbacks: + # We can have multiple versions of the same callback + # if so, we store them as a list of states to restore + if not isinstance(stateful_callbacks[name], list): + stateful_callbacks[name] = [stateful_callbacks[name]] + stateful_callbacks[name].append(callback.state()) + else: + stateful_callbacks[name] = callback.state() + self.stateful_callbacks = stateful_callbacks + + def save_to_json(self, json_path: str): + """Save the content of this instance in JSON format inside `json_path`.""" + json_string = json.dumps(dataclasses.asdict(self), indent=2, sort_keys=True) + "\n" + with open(json_path, "w", encoding="utf-8") as f: + f.write(json_string) + + @classmethod + def load_from_json(cls, json_path: str): + """Create an instance from the content of `json_path`.""" + with open(json_path, encoding="utf-8") as f: + text = f.read() + return cls(**json.loads(text)) + + def compute_steps(self, args, max_steps): + """ + Calculates and stores the absolute value for logging, + eval, and save steps based on if it was a proportion + or not. + """ + for step_kind in ("logging", "eval", "save"): + num_steps = getattr(args, f"{step_kind}_steps") + if num_steps is not None: + if num_steps < 1: + num_steps = math.ceil(max_steps * num_steps) + setattr(self, f"{step_kind}_steps", num_steps) + + def init_training_references(self, trainer, max_steps, num_train_epochs, trial): + """ + Stores the initial training references needed in `self` + """ + if trainer.hp_name is not None and trainer._trial is not None: + # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial + # parameter to Train when using DDP. + self.trial_name = trainer.hp_name(trainer._trial) + self.trial_params = None + if trial is not None: + from transformers.integrations import hp_params + + assignments = trial.assignments if trainer.hp_search_backend == HPSearchBackend.SIGOPT else trial + self.trial_params = hp_params(assignments) + + self.max_steps = max_steps + self.num_train_epochs = num_train_epochs + self.is_local_process_zero = trainer.is_local_process_zero() + self.is_world_process_zero = trainer.is_world_process_zero() + + +class ExportableState: + """ + A class for objects that include the ability to have its state + be saved during `Trainer._save_checkpoint` and loaded back in during + `Trainer._load_from_checkpoint`. + + These must implement a `state` function that gets called during the respective + Trainer function call. It should only include parameters and attributes needed to + recreate the state at a particular time, to avoid utilizing pickle/maintain standard + file IO writing. + + Example: + + ```python + class EarlyStoppingCallback(TrainerCallback, ExportableState): + def __init__(self, early_stopping_patience: int = 1, early_stopping_threshold: Optional[float] = 0.0): + self.early_stopping_patience = early_stopping_patience + self.early_stopping_threshold = early_stopping_threshold + # early_stopping_patience_counter denotes the number of times validation metrics failed to improve. + self.early_stopping_patience_counter = 0 + + def state(self) -> dict: + return { + "args": { + "early_stopping_patience": self.early_stopping_patience, + "early_stopping_threshold": self.early_stopping_threshold, + }, + "attributes": { + "early_stopping_patience_counter": self.early_stopping_patience_counter, + } + } + ```""" + + def state(self) -> dict: + raise NotImplementedError("You must implement a `state` function to utilize this class.") + + @classmethod + def from_state(cls, state): + instance = cls(**state["args"]) + for k, v in state["attributes"].items(): + setattr(instance, k, v) + return instance + + +@dataclass +class TrainerControl(ExportableState): + """ + A class that handles the [`Trainer`] control flow. This class is used by the [`TrainerCallback`] to activate some + switches in the training loop. + + Args: + should_training_stop (`bool`, *optional*, defaults to `False`): + Whether or not the training should be interrupted. + + If `True`, this variable will not be set back to `False`. The training will just stop. + should_epoch_stop (`bool`, *optional*, defaults to `False`): + Whether or not the current epoch should be interrupted. + + If `True`, this variable will be set back to `False` at the beginning of the next epoch. + should_save (`bool`, *optional*, defaults to `False`): + Whether or not the model should be saved at this step. + + If `True`, this variable will be set back to `False` at the beginning of the next step. + should_evaluate (`bool`, *optional*, defaults to `False`): + Whether or not the model should be evaluated at this step. + + If `True`, this variable will be set back to `False` at the beginning of the next step. + should_log (`bool`, *optional*, defaults to `False`): + Whether or not the logs should be reported at this step. + + If `True`, this variable will be set back to `False` at the beginning of the next step. + """ + + should_training_stop: bool = False + should_epoch_stop: bool = False + should_save: bool = False + should_evaluate: bool = False + should_log: bool = False + + def _new_training(self): + """Internal method that resets the variable for a new training.""" + self.should_training_stop = False + + def _new_epoch(self): + """Internal method that resets the variable for a new epoch.""" + self.should_epoch_stop = False + + def _new_step(self): + """Internal method that resets the variable for a new step.""" + self.should_save = False + self.should_evaluate = False + self.should_log = False + + def state(self) -> dict: + return { + "args": { + "should_training_stop": self.should_training_stop, + "should_epoch_stop": self.should_epoch_stop, + "should_save": self.should_save, + "should_evaluate": self.should_evaluate, + "should_log": self.should_log, + }, + "attributes": {}, + } + + +class TrainerCallback: + # no-format + """ + A class for objects that will inspect the state of the training loop at some events and take some decisions. At + each of those events the following arguments are available: + + Args: + args ([`TrainingArguments`]): + The training arguments used to instantiate the [`Trainer`]. + state ([`TrainerState`]): + The current state of the [`Trainer`]. + control ([`TrainerControl`]): + The object that is returned to the [`Trainer`] and can be used to make some decisions. + model ([`PreTrainedModel`] or `torch.nn.Module`): + The model being trained. + tokenizer ([`PreTrainedTokenizer`]): + The tokenizer used for encoding the data. This is deprecated in favour of `processing_class`. + processing_class ([`PreTrainedTokenizer` or `BaseImageProcessor` or `ProcessorMixin` or `FeatureExtractionMixin`]): + The processing class used for encoding the data. Can be a tokenizer, a processor, an image processor or a feature extractor. + optimizer (`torch.optim.Optimizer`): + The optimizer used for the training steps. + lr_scheduler (`torch.optim.lr_scheduler.LambdaLR`): + The scheduler used for setting the learning rate. + train_dataloader (`torch.utils.data.DataLoader`, *optional*): + The current dataloader used for training. + eval_dataloader (`torch.utils.data.DataLoader`, *optional*): + The current dataloader used for evaluation. + metrics (`dict[str, float]`): + The metrics computed by the last evaluation phase. + + Those are only accessible in the event `on_evaluate`. + logs (`dict[str, float]`): + The values to log. + + Those are only accessible in the event `on_log`. + + The `control` object is the only one that can be changed by the callback, in which case the event that changes it + should return the modified version. + + The argument `args`, `state` and `control` are positionals for all events, all the others are grouped in `kwargs`. + You can unpack the ones you need in the signature of the event using them. As an example, see the code of the + simple [`~transformers.PrinterCallback`]. + + Example: + + ```python + class PrinterCallback(TrainerCallback): + def on_log(self, args, state, control, logs=None, **kwargs): + _ = logs.pop("total_flos", None) + if state.is_local_process_zero: + print(logs) + ```""" + + def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + """ + Event called at the end of the initialization of the [`Trainer`]. + """ + pass + + def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + """ + Event called at the beginning of training. + """ + pass + + def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + """ + Event called at the end of training. + """ + pass + + def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + """ + Event called at the beginning of an epoch. + """ + pass + + def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + """ + Event called at the end of an epoch. + """ + pass + + def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + """ + Event called at the beginning of a training step. If using gradient accumulation, one training step might take + several inputs. + """ + pass + + def on_pre_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + """ + Event called before the optimizer step but after gradient clipping. Useful for monitoring gradients. + """ + pass + + def on_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + """ + Event called after the optimizer step but before gradients are zeroed out. Useful for monitoring gradients. + """ + pass + + def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + """ + Event called at the end of an substep during gradient accumulation. + """ + pass + + def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + """ + Event called at the end of a training step. If using gradient accumulation, one training step might take + several inputs. + """ + pass + + def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + """ + Event called after an evaluation phase. + """ + pass + + def on_predict(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics, **kwargs): + """ + Event called after a successful prediction. + """ + pass + + def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + """ + Event called after a checkpoint save. + """ + pass + + def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + """ + Event called after logging the last logs. + """ + pass + + def on_prediction_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + """ + Event called after a prediction step. + """ + pass + + +class CallbackHandler(TrainerCallback): + """Internal class that just calls the list of callbacks in order.""" + + def __init__(self, callbacks, model, processing_class, optimizer, lr_scheduler): + self.callbacks = [] + for cb in callbacks: + self.add_callback(cb) + self.model = model + self.processing_class = processing_class + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + self.train_dataloader = None + self.eval_dataloader = None + + if not any(isinstance(cb, DefaultFlowCallback) for cb in self.callbacks): + logger.warning( + "The Trainer will not work properly if you don't have a `DefaultFlowCallback` in its callbacks. You\n" + + "should add one before training with `trainer.add_callback(DefaultFlowCallback). The current list of" + + "callbacks is\n:" + + self.callback_list + ) + + def add_callback(self, callback): + cb = callback() if isinstance(callback, type) else callback + cb_class = callback if isinstance(callback, type) else callback.__class__ + if cb_class in [c.__class__ for c in self.callbacks]: + logger.warning( + f"You are adding a {cb_class} to the callbacks of this Trainer, but there is already one. The current" + + "list of callbacks is\n:" + + self.callback_list + ) + self.callbacks.append(cb) + + def pop_callback(self, callback): + if isinstance(callback, type): + for cb in self.callbacks: + if isinstance(cb, callback): + self.callbacks.remove(cb) + return cb + else: + for cb in self.callbacks: + if cb == callback: + self.callbacks.remove(cb) + return cb + + def remove_callback(self, callback): + if isinstance(callback, type): + for cb in self.callbacks: + if isinstance(cb, callback): + self.callbacks.remove(cb) + return + else: + self.callbacks.remove(callback) + + @property + def callback_list(self): + return "\n".join(cb.__class__.__name__ for cb in self.callbacks) + + def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl): + return self.call_event("on_init_end", args, state, control) + + def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl): + control.should_training_stop = False + return self.call_event("on_train_begin", args, state, control) + + def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl): + return self.call_event("on_train_end", args, state, control) + + def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl): + control.should_epoch_stop = False + return self.call_event("on_epoch_begin", args, state, control) + + def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl): + return self.call_event("on_epoch_end", args, state, control) + + def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl): + control.should_log = False + control.should_evaluate = False + control.should_save = False + return self.call_event("on_step_begin", args, state, control) + + def on_pre_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl): + return self.call_event("on_pre_optimizer_step", args, state, control) + + def on_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl): + return self.call_event("on_optimizer_step", args, state, control) + + def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl): + return self.call_event("on_substep_end", args, state, control) + + def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl): + return self.call_event("on_step_end", args, state, control) + + def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics): + control.should_evaluate = False + return self.call_event("on_evaluate", args, state, control, metrics=metrics) + + def on_predict(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics): + return self.call_event("on_predict", args, state, control, metrics=metrics) + + def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl): + control.should_save = False + return self.call_event("on_save", args, state, control) + + def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, logs): + control.should_log = False + return self.call_event("on_log", args, state, control, logs=logs) + + def on_prediction_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl): + return self.call_event("on_prediction_step", args, state, control) + + def call_event(self, event, args, state, control, **kwargs): + for callback in self.callbacks: + result = getattr(callback, event)( + args, + state, + control, + model=self.model, + processing_class=self.processing_class, + optimizer=self.optimizer, + lr_scheduler=self.lr_scheduler, + train_dataloader=self.train_dataloader, + eval_dataloader=self.eval_dataloader, + **kwargs, + ) + # A Callback can skip the return of `control` if it doesn't change it. + if result is not None: + control = result + return control + + +class DefaultFlowCallback(TrainerCallback): + """ + A [`TrainerCallback`] that handles the default flow of the training loop for logs, evaluation and checkpoints. + """ + + def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + # Log + if state.global_step == 1 and args.logging_first_step: + control.should_log = True + if args.logging_strategy == IntervalStrategy.STEPS and state.global_step % state.logging_steps == 0: + control.should_log = True + + # Evaluate + if ( + args.eval_strategy == IntervalStrategy.STEPS + and state.global_step % state.eval_steps == 0 + and args.eval_delay <= state.global_step + ): + control.should_evaluate = True + + # Save + if ( + args.save_strategy == SaveStrategy.STEPS + and state.save_steps > 0 + and state.global_step % state.save_steps == 0 + ): + control.should_save = True + + # End training + if state.global_step >= state.max_steps: + control.should_training_stop = True + # Save the model at the end if we have a save strategy + if args.save_strategy == SaveStrategy.STEPS: + control.should_save = True + + return control + + def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + # Log + if args.logging_strategy == IntervalStrategy.EPOCH: + control.should_log = True + + # Evaluate + if args.eval_strategy == IntervalStrategy.EPOCH and args.eval_delay <= state.epoch: + control.should_evaluate = True + + # Save + if args.save_strategy == SaveStrategy.EPOCH: + control.should_save = True + + return control + + +class ProgressCallback(TrainerCallback): + """ + A [`TrainerCallback`] that displays the progress of training or evaluation. + You can modify `max_str_len` to control how long strings are truncated when logging. + """ + + def __init__(self, max_str_len: int = 100): + """ + Initialize the callback with optional max_str_len parameter to control string truncation length. + + Args: + max_str_len (`int`): + Maximum length of strings to display in logs. + Longer strings will be truncated with a message. + """ + self.training_bar = None + self.prediction_bar = None + self.max_str_len = max_str_len + + def on_train_begin(self, args, state, control, **kwargs): + if state.is_world_process_zero: + self.training_bar = tqdm(total=state.max_steps, dynamic_ncols=True) + self.current_step = 0 + + def on_step_end(self, args, state, control, **kwargs): + if state.is_world_process_zero: + self.training_bar.update(state.global_step - self.current_step) + self.current_step = state.global_step + + def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs): + if state.is_world_process_zero and has_length(eval_dataloader): + if self.prediction_bar is None: + self.prediction_bar = tqdm( + total=len(eval_dataloader), leave=self.training_bar is None, dynamic_ncols=True + ) + self.prediction_bar.update(1) + + def on_evaluate(self, args, state, control, **kwargs): + if state.is_world_process_zero: + if self.prediction_bar is not None: + self.prediction_bar.close() + self.prediction_bar = None + + def on_predict(self, args, state, control, **kwargs): + if state.is_world_process_zero: + if self.prediction_bar is not None: + self.prediction_bar.close() + self.prediction_bar = None + + def on_log(self, args, state, control, logs=None, **kwargs): + if state.is_world_process_zero and self.training_bar is not None: + # make a shallow copy of logs so we can mutate the fields copied + # but avoid doing any value pickling. + shallow_logs = {} + for k, v in logs.items(): + if isinstance(v, str) and len(v) > self.max_str_len: + shallow_logs[k] = ( + f"[String too long to display, length: {len(v)} > {self.max_str_len}. " + "Consider increasing `max_str_len` if needed.]" + ) + else: + shallow_logs[k] = v + _ = shallow_logs.pop("total_flos", None) + # round numbers so that it looks better in console + if "epoch" in shallow_logs: + shallow_logs["epoch"] = round(shallow_logs["epoch"], 2) + self.training_bar.write(str(shallow_logs)) + + def on_train_end(self, args, state, control, **kwargs): + if state.is_world_process_zero: + self.training_bar.close() + self.training_bar = None + + +class PrinterCallback(TrainerCallback): + """ + A bare [`TrainerCallback`] that just prints the logs. + """ + + def on_log(self, args, state, control, logs=None, **kwargs): + _ = logs.pop("total_flos", None) + if state.is_local_process_zero: + print(logs) + + +class EarlyStoppingCallback(TrainerCallback, ExportableState): + """ + A [`TrainerCallback`] that handles early stopping. + + Args: + early_stopping_patience (`int`): + Use with `metric_for_best_model` to stop training when the specified metric worsens for + `early_stopping_patience` evaluation calls. + early_stopping_threshold(`float`, *optional*): + Use with TrainingArguments `metric_for_best_model` and `early_stopping_patience` to denote how much the + specified metric must improve to satisfy early stopping conditions. ` + + This callback depends on [`TrainingArguments`] argument *load_best_model_at_end* functionality to set best_metric + in [`TrainerState`]. Note that if the [`TrainingArguments`] argument *save_steps* differs from *eval_steps*, the + early stopping will not occur until the next save step. + """ + + def __init__(self, early_stopping_patience: int = 1, early_stopping_threshold: Optional[float] = 0.0): + self.early_stopping_patience = early_stopping_patience + self.early_stopping_threshold = early_stopping_threshold + # early_stopping_patience_counter denotes the number of times validation metrics failed to improve. + self.early_stopping_patience_counter = 0 + + def check_metric_value(self, args, state, control, metric_value): + # best_metric is set by code for load_best_model + operator = np.greater if args.greater_is_better else np.less + if state.best_metric is None or ( + operator(metric_value, state.best_metric) + and abs(metric_value - state.best_metric) > self.early_stopping_threshold + ): + self.early_stopping_patience_counter = 0 + else: + self.early_stopping_patience_counter += 1 + + def on_train_begin(self, args, state, control, **kwargs): + if not args.load_best_model_at_end: + logger.warning( + "Using EarlyStoppingCallback without load_best_model_at_end=True. " + "Once training is finished, the best model will not be loaded automatically." + ) + assert args.metric_for_best_model is not None, ( + "EarlyStoppingCallback requires metric_for_best_model to be defined" + ) + assert args.eval_strategy != IntervalStrategy.NO, ( + "EarlyStoppingCallback requires IntervalStrategy of steps or epoch" + ) + + def on_evaluate(self, args, state, control, metrics, **kwargs): + metric_to_check = args.metric_for_best_model + if not metric_to_check.startswith("eval_"): + metric_to_check = f"eval_{metric_to_check}" + metric_value = metrics.get(metric_to_check) + + if metric_value is None: + logger.warning( + f"early stopping required metric_for_best_model, but did not find {metric_to_check} so early stopping" + " is disabled" + ) + return + + self.check_metric_value(args, state, control, metric_value) + if self.early_stopping_patience_counter >= self.early_stopping_patience: + control.should_training_stop = True + + def state(self) -> dict: + return { + "args": { + "early_stopping_patience": self.early_stopping_patience, + "early_stopping_threshold": self.early_stopping_threshold, + }, + "attributes": { + "early_stopping_patience_counter": self.early_stopping_patience_counter, + }, + } diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/trainer_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/trainer_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2e71367c70c742adc12087a7a82d3ddfe94dca6b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/trainer_utils.py @@ -0,0 +1,911 @@ +# Copyright 2020-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +PyTorch-independent utilities for the Trainer class. +""" + +import copy +import functools +import gc +import inspect +import os +import random +import re +import threading +import time +from typing import Any, Callable, NamedTuple, Optional, Union + +import numpy as np + +from .utils import ( + ExplicitEnum, + is_psutil_available, + is_tf_available, + is_torch_available, + is_torch_cuda_available, + is_torch_hpu_available, + is_torch_mlu_available, + is_torch_mps_available, + is_torch_musa_available, + is_torch_npu_available, + is_torch_xla_available, + is_torch_xpu_available, + requires_backends, +) + + +if is_torch_available(): + import torch + + +def seed_worker(worker_id: int, num_workers: int, rank: int): + """ + Helper function to set worker seed during Dataloader initialization. + """ + init_seed = torch.initial_seed() % 2**32 + worker_seed = num_workers * rank + init_seed + set_seed(worker_seed) + + +def enable_full_determinism(seed: int, warn_only: bool = False): + """ + Helper function for reproducible behavior during distributed training. See + - https://pytorch.org/docs/stable/notes/randomness.html for pytorch + - https://www.tensorflow.org/api_docs/python/tf/config/experimental/enable_op_determinism for tensorflow + """ + # set seed first + set_seed(seed) + + if is_torch_available(): + # Enable PyTorch deterministic mode. This potentially requires either the environment + # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set, + # depending on the CUDA version, so we set them both here + os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8" + # The environment variable required to enable deterministic mode on Ascend NPUs. + os.environ["ASCEND_LAUNCH_BLOCKING"] = "1" + os.environ["HCCL_DETERMINISTIC"] = "1" + + os.environ["FLASH_ATTENTION_DETERMINISTIC"] = "1" + torch.use_deterministic_algorithms(True, warn_only=warn_only) + + # Enable CUDNN deterministic mode + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + if is_tf_available(): + import tensorflow as tf + + tf.config.experimental.enable_op_determinism() + + +def set_seed(seed: int, deterministic: bool = False): + """ + Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch` and/or `tf` (if installed). + + Args: + seed (`int`): + The seed to set. + deterministic (`bool`, *optional*, defaults to `False`): + Whether to use deterministic algorithms where available. Can slow down training. + """ + random.seed(seed) + np.random.seed(seed) + if is_torch_available(): + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + # ^^ safe to call this function even if cuda is not available + if deterministic: + torch.use_deterministic_algorithms(True) + if is_torch_mlu_available(): + torch.mlu.manual_seed_all(seed) + if is_torch_musa_available(): + torch.musa.manual_seed_all(seed) + if is_torch_npu_available(): + torch.npu.manual_seed_all(seed) + if is_torch_hpu_available(): + torch.hpu.manual_seed_all(seed) + if is_torch_xpu_available(): + torch.xpu.manual_seed_all(seed) + if is_tf_available(): + import tensorflow as tf + + tf.random.set_seed(seed) + if deterministic: + tf.config.experimental.enable_op_determinism() + + +def neftune_post_forward_hook(module, input, output): + """ + Implements the NEFTune forward pass for the model using forward hooks. Note this works only for torch.nn.Embedding + layers. This method is slightly adapted from the original source code that can be found here: + https://github.com/neelsjain/NEFTune Simply add it to your model as follows: + ```python + model = ... + model.embed_tokens.neftune_noise_alpha = 0.1 + model.embed_tokens.register_forward_hook(neftune_post_forward_hook) + ``` + Args: + module (`torch.nn.Module`): + The embedding module where the hook is attached. Note that you need to set `module.neftune_noise_alpha` to + the desired noise alpha value. + input (`torch.Tensor`): + The input tensor to the model. + output (`torch.Tensor`): + The output tensor of the model (i.e. the embeddings). + """ + if module.training: + dims = torch.tensor(output.size(1) * output.size(2)) + mag_norm = module.neftune_noise_alpha / torch.sqrt(dims) + output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm) + return output + + +class EvalPrediction: + """ + Evaluation output (always contains labels), to be used to compute metrics. + + Parameters: + predictions (`np.ndarray`): Predictions of the model. + label_ids (`np.ndarray`): Targets to be matched. + inputs (`np.ndarray`, *optional*): Input data passed to the model. + losses (`np.ndarray`, *optional*): Loss values computed during evaluation. + """ + + def __init__( + self, + predictions: Union[np.ndarray, tuple[np.ndarray]], + label_ids: Union[np.ndarray, tuple[np.ndarray]], + inputs: Optional[Union[np.ndarray, tuple[np.ndarray]]] = None, + losses: Optional[Union[np.ndarray, tuple[np.ndarray]]] = None, + ): + self.predictions = predictions + self.label_ids = label_ids + self.inputs = inputs + self.losses = losses + self.elements = (self.predictions, self.label_ids) + if self.inputs is not None: + self.elements += (self.inputs,) + if self.losses is not None: + self.elements += (self.losses,) + + def __iter__(self): + return iter(self.elements) + + def __getitem__(self, idx): + if idx < 0 or idx >= len(self.elements): + raise IndexError("tuple index out of range") + return self.elements[idx] + + +class EvalLoopOutput(NamedTuple): + predictions: Union[np.ndarray, tuple[np.ndarray]] + label_ids: Optional[Union[np.ndarray, tuple[np.ndarray]]] + metrics: Optional[dict[str, float]] + num_samples: Optional[int] + + +class PredictionOutput(NamedTuple): + predictions: Union[np.ndarray, tuple[np.ndarray]] + label_ids: Optional[Union[np.ndarray, tuple[np.ndarray]]] + metrics: Optional[dict[str, float]] + + +class TrainOutput(NamedTuple): + global_step: int + training_loss: float + metrics: dict[str, float] + + +PREFIX_CHECKPOINT_DIR = "checkpoint" +_re_checkpoint = re.compile(r"^" + PREFIX_CHECKPOINT_DIR + r"\-(\d+)$") + + +def get_last_checkpoint(folder): + content = os.listdir(folder) + checkpoints = [ + path + for path in content + if _re_checkpoint.search(path) is not None and os.path.isdir(os.path.join(folder, path)) + ] + if len(checkpoints) == 0: + return + return os.path.join(folder, max(checkpoints, key=lambda x: int(_re_checkpoint.search(x).groups()[0]))) + + +class IntervalStrategy(ExplicitEnum): + NO = "no" + STEPS = "steps" + EPOCH = "epoch" + + +class SaveStrategy(ExplicitEnum): + NO = "no" + STEPS = "steps" + EPOCH = "epoch" + BEST = "best" + + +class EvaluationStrategy(ExplicitEnum): + NO = "no" + STEPS = "steps" + EPOCH = "epoch" + + +class HubStrategy(ExplicitEnum): + END = "end" + EVERY_SAVE = "every_save" + CHECKPOINT = "checkpoint" + ALL_CHECKPOINTS = "all_checkpoints" + + +class BestRun(NamedTuple): + """ + The best run found by a hyperparameter search (see [`~Trainer.hyperparameter_search`]). + + Parameters: + run_id (`str`): + The id of the best run (if models were saved, the corresponding checkpoint will be in the folder ending + with run-{run_id}). + objective (`float`): + The objective that was obtained for this run. + hyperparameters (`dict[str, Any]`): + The hyperparameters picked to get this run. + run_summary (`Optional[Any]`): + A summary of tuning experiments. `ray.tune.ExperimentAnalysis` object for Ray backend. + """ + + run_id: str + objective: Union[float, list[float]] + hyperparameters: dict[str, Any] + run_summary: Optional[Any] = None + + +def default_compute_objective(metrics: dict[str, float]) -> float: + """ + The default objective to maximize/minimize when doing an hyperparameter search. It is the evaluation loss if no + metrics are provided to the [`Trainer`], the sum of all metrics otherwise. + + Args: + metrics (`dict[str, float]`): The metrics returned by the evaluate method. + + Return: + `float`: The objective to minimize or maximize + """ + metrics = copy.deepcopy(metrics) + loss = metrics.pop("eval_loss", None) + _ = metrics.pop("epoch", None) + # Remove speed metrics + speed_metrics = [ + m for m in metrics if m.endswith("_runtime") or m.endswith("_per_second") or m.endswith("_compilation_time") + ] + for sm in speed_metrics: + _ = metrics.pop(sm, None) + return loss if len(metrics) == 0 else sum(metrics.values()) + + +def default_hp_space_optuna(trial) -> dict[str, float]: + from .integrations import is_optuna_available + + assert is_optuna_available(), "This function needs Optuna installed: `pip install optuna`" + return { + "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True), + "num_train_epochs": trial.suggest_int("num_train_epochs", 1, 5), + "seed": trial.suggest_int("seed", 1, 40), + "per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [4, 8, 16, 32, 64]), + } + + +def default_hp_space_ray(trial) -> dict[str, Any]: + from .integrations import is_ray_tune_available + + assert is_ray_tune_available(), "This function needs ray installed: `pip install ray[tune]`" + from ray import tune + + return { + "learning_rate": tune.loguniform(1e-6, 1e-4), + "num_train_epochs": tune.choice(list(range(1, 6))), + "seed": tune.uniform(1, 40), + "per_device_train_batch_size": tune.choice([4, 8, 16, 32, 64]), + } + + +def default_hp_space_sigopt(trial): + return [ + {"bounds": {"min": 1e-6, "max": 1e-4}, "name": "learning_rate", "type": "double", "transformation": "log"}, + {"bounds": {"min": 1, "max": 6}, "name": "num_train_epochs", "type": "int"}, + {"bounds": {"min": 1, "max": 40}, "name": "seed", "type": "int"}, + { + "categorical_values": ["4", "8", "16", "32", "64"], + "name": "per_device_train_batch_size", + "type": "categorical", + }, + ] + + +def default_hp_space_wandb(trial) -> dict[str, Any]: + from .integrations import is_wandb_available + + if not is_wandb_available(): + raise ImportError("This function needs wandb installed: `pip install wandb`") + + return { + "method": "random", + "metric": {"name": "objective", "goal": "minimize"}, + "parameters": { + "learning_rate": {"distribution": "uniform", "min": 1e-6, "max": 1e-4}, + "num_train_epochs": {"distribution": "int_uniform", "min": 1, "max": 6}, + "seed": {"distribution": "int_uniform", "min": 1, "max": 40}, + "per_device_train_batch_size": {"values": [4, 8, 16, 32, 64]}, + }, + } + + +class HPSearchBackend(ExplicitEnum): + OPTUNA = "optuna" + RAY = "ray" + SIGOPT = "sigopt" + WANDB = "wandb" + + +def is_main_process(local_rank): + """ + Whether or not the current process is the local process, based on `xr.global_ordinal()` (for TPUs) first, then on + `local_rank`. + """ + if is_torch_xla_available(): + import torch_xla.runtime as xr + + return xr.global_ordinal() == 0 + return local_rank in [-1, 0] + + +def total_processes_number(local_rank): + """ + Return the number of processes launched in parallel. Works with `torch.distributed` and TPUs. + """ + if is_torch_xla_available(): + import torch_xla.runtime as xr + + return xr.world_size() + elif local_rank != -1 and is_torch_available(): + import torch + + return torch.distributed.get_world_size() + return 1 + + +def speed_metrics(split, start_time, num_samples=None, num_steps=None, num_tokens=None): + """ + Measure and return speed performance metrics. + + This function requires a time snapshot `start_time` before the operation to be measured starts and this function + should be run immediately after the operation to be measured has completed. + + Args: + - split: name to prefix metric (like train, eval, test...) + - start_time: operation start time + - num_samples: number of samples processed + - num_steps: number of steps processed + - num_tokens: number of tokens processed + """ + runtime = time.time() - start_time + result = {f"{split}_runtime": round(runtime, 4)} + if runtime == 0: + return result + if num_samples is not None: + samples_per_second = num_samples / runtime + result[f"{split}_samples_per_second"] = round(samples_per_second, 3) + if num_steps is not None: + steps_per_second = num_steps / runtime + result[f"{split}_steps_per_second"] = round(steps_per_second, 3) + if num_tokens is not None: + tokens_per_second = num_tokens / runtime + result[f"{split}_tokens_per_second"] = round(tokens_per_second, 3) + return result + + +class SchedulerType(ExplicitEnum): + """ + Scheduler names for the parameter `lr_scheduler_type` in [`TrainingArguments`]. + By default, it uses "linear". Internally, this retrieves `get_linear_schedule_with_warmup` scheduler from [`Trainer`]. + Scheduler types: + - "linear" = [`get_linear_schedule_with_warmup`] + - "cosine" = [`get_cosine_schedule_with_warmup`] + - "cosine_with_restarts" = [`get_cosine_with_hard_restarts_schedule_with_warmup`] + - "polynomial" = [`get_polynomial_decay_schedule_with_warmup`] + - "constant" = [`get_constant_schedule`] + - "constant_with_warmup" = [`get_constant_schedule_with_warmup`] + - "inverse_sqrt" = [`get_inverse_sqrt_schedule`] + - "reduce_lr_on_plateau" = [`get_reduce_on_plateau_schedule`] + - "cosine_with_min_lr" = [`get_cosine_with_min_lr_schedule_with_warmup`] + - "cosine_warmup_with_min_lr" = [`get_cosine_with_min_lr_schedule_with_warmup_lr_rate`] + - "warmup_stable_decay" = [`get_wsd_schedule`] + """ + + LINEAR = "linear" + COSINE = "cosine" + COSINE_WITH_RESTARTS = "cosine_with_restarts" + POLYNOMIAL = "polynomial" + CONSTANT = "constant" + CONSTANT_WITH_WARMUP = "constant_with_warmup" + INVERSE_SQRT = "inverse_sqrt" + REDUCE_ON_PLATEAU = "reduce_lr_on_plateau" + COSINE_WITH_MIN_LR = "cosine_with_min_lr" + COSINE_WARMUP_WITH_MIN_LR = "cosine_warmup_with_min_lr" + WARMUP_STABLE_DECAY = "warmup_stable_decay" + + +class TrainerMemoryTracker: + """ + A helper class that tracks cpu and gpu memory. + + This class will silently skip unless `psutil` is available. Install with `pip install psutil`. + + When a stage completes, it can pass metrics dict to update with the memory metrics gathered during this stage. + + Example : + + ```python + self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics) + self._memory_tracker.start() + # code ... + metrics = {"train_runtime": 10.5} + self._memory_tracker.stop_and_update_metrics(metrics) + ``` + + At the moment GPU tracking is only for `pytorch`, but can be extended to support `tensorflow`. + + To understand this class' intricacies please read the documentation of [`~Trainer.log_metrics`]. + """ + + # map trainer methods to metrics prefix + stages = { + "__init__": "init", + "train": "train", + "_inner_training_loop": "train", + "evaluate": "eval", + "predict": "test", + } + + def __init__(self, skip_memory_metrics=False): + self.skip_memory_metrics = skip_memory_metrics + + if not is_psutil_available(): + # soft dependency on psutil + self.skip_memory_metrics = True + + if self.skip_memory_metrics: + return + + import psutil + + if is_torch_cuda_available() or is_torch_mlu_available() or is_torch_musa_available(): + import torch + + self.torch = torch + self.gpu = {} + elif is_torch_mps_available(): + import torch + + self.torch = torch + self.gpu = {} + elif is_torch_xpu_available(): + import torch + + self.torch = torch + self.gpu = {} + elif is_torch_npu_available(): + import torch + + self.torch = torch + self.gpu = {} + elif is_torch_hpu_available(): + import torch + + self.torch = torch + self.gpu = {} + else: + self.torch = None + + self.process = psutil.Process() + + self.cur_stage = None + self.cpu = {} + self.init_reported = False + + def derive_stage(self): + """derives the stage/caller name automatically""" + caller = inspect.currentframe().f_back.f_back.f_code.co_name + if caller in self.stages: + return self.stages[caller] + else: + raise ValueError( + f"was called from {caller}, but only expect to be called from one of {self.stages.keys()}" + ) + + def cpu_mem_used(self): + """get resident set size memory for the current process""" + return self.process.memory_info().rss + + def peak_monitor_func(self): + self.cpu_mem_used_peak = -1 + + while True: + self.cpu_mem_used_peak = max(self.cpu_mem_used(), self.cpu_mem_used_peak) + + # can't sleep or will not catch the peak right (this comment is here on purpose) + # time.sleep(0.001) # 1msec + + if not self.peak_monitoring: + break + + def start(self): + """start tracking for the caller's stage""" + if self.skip_memory_metrics: + return + + stage = self.derive_stage() + # deal with nested calls of eval during train - simply ignore those + if self.cur_stage is not None and self.cur_stage != stage: + return + + self.cur_stage = stage + + gc.collect() + + if self.torch is not None: + if torch.cuda.is_available(): + self.torch.cuda.reset_peak_memory_stats() + self.torch.cuda.empty_cache() + elif is_torch_mlu_available(): + self.torch.mlu.reset_peak_memory_stats() + self.torch.mlu.empty_cache() + elif is_torch_musa_available(): + self.torch.musa.reset_peak_memory_stats() + self.torch.musa.empty_cache() + elif is_torch_xpu_available(): + self.torch.xpu.reset_peak_memory_stats() + self.torch.xpu.empty_cache() + elif is_torch_npu_available(): + self.torch.npu.reset_peak_memory_stats() + self.torch.npu.empty_cache() + elif is_torch_hpu_available(): + self.torch.hpu.reset_peak_memory_stats() + # not available on hpu as it reserves all device memory for the current process + # self.torch.hpu.empty_cache() + elif is_torch_mps_available(): + self.torch.mps.empty_cache() + + # gpu + if self.torch is not None: + if torch.cuda.is_available(): + self.gpu_mem_used_at_start = self.torch.cuda.memory_allocated() + elif is_torch_mlu_available(): + self.gpu_mem_used_at_start = self.torch.mlu.memory_allocated() + elif is_torch_musa_available(): + self.gpu_mem_used_at_start = self.torch.musa.memory_allocated() + elif is_torch_xpu_available(): + self.gpu_mem_used_at_start = self.torch.xpu.memory_allocated() + elif is_torch_npu_available(): + self.gpu_mem_used_at_start = self.torch.npu.memory_allocated() + elif is_torch_hpu_available(): + self.gpu_mem_used_at_start = self.torch.hpu.memory_allocated() + elif is_torch_mps_available(): + self.gpu_mem_used_at_start = self.torch.mps.current_allocated_memory() + + # cpu + self.cpu_mem_used_at_start = self.cpu_mem_used() + + self.peak_monitoring = True + peak_monitor_thread = threading.Thread(target=self.peak_monitor_func) + peak_monitor_thread.daemon = True + peak_monitor_thread.start() + + def stop(self, stage): + """stop tracking for the passed stage""" + + # deal with nested calls of eval during train - simply ignore those + if self.cur_stage is not None and self.cur_stage != stage: + return + + # this sends a signal to peak_monitor_func to complete its loop + self.peak_monitoring = False + + # first ensure all objects get collected and their memory is freed + gc.collect() + + if self.torch is not None: + if torch.cuda.is_available(): + self.torch.cuda.empty_cache() + elif is_torch_mlu_available(): + self.torch.mlu.empty_cache() + elif is_torch_musa_available(): + self.torch.musa.empty_cache() + elif is_torch_xpu_available(): + self.torch.xpu.empty_cache() + elif is_torch_npu_available(): + self.torch.npu.empty_cache() + elif is_torch_hpu_available(): + # not available on hpu as it reserves all device memory for the current process + # self.torch.npu.empty_cache() + pass + elif is_torch_mps_available(): + self.torch.mps.empty_cache() + + # concepts: + # - alloc_delta: the difference of allocated memory between the end and the start + # - peaked_delta: the difference between the peak memory and the current memory + # in order to know how much memory the measured code consumed one needs to sum these two + + # gpu + if self.torch is not None: + if torch.cuda.is_available(): + self.gpu_mem_used_now = self.torch.cuda.memory_allocated() + self.gpu_mem_used_peak = self.torch.cuda.max_memory_allocated() + elif is_torch_mlu_available(): + self.gpu_mem_used_now = self.torch.mlu.memory_allocated() + self.gpu_mem_used_peak = self.torch.mlu.max_memory_allocated() + elif is_torch_musa_available(): + self.gpu_mem_used_now = self.torch.musa.memory_allocated() + self.gpu_mem_used_peak = self.torch.musa.max_memory_allocated() + elif is_torch_xpu_available(): + self.gpu_mem_used_now = self.torch.xpu.memory_allocated() + self.gpu_mem_used_peak = self.torch.xpu.max_memory_allocated() + elif is_torch_npu_available(): + self.gpu_mem_used_now = self.torch.npu.memory_allocated() + self.gpu_mem_used_peak = self.torch.npu.max_memory_allocated() + elif is_torch_hpu_available(): + self.gpu_mem_used_now = self.torch.hpu.memory_allocated() + self.gpu_mem_used_peak = self.torch.hpu.max_memory_allocated() + elif is_torch_mps_available(): + self.gpu_mem_used_now = self.torch.mps.current_allocated_memory() + # self.torch.mps.max_memory_allocated() does not exist yet + self.gpu_mem_used_peak = None + + else: + raise ValueError("No available GPU device found!") + + self.gpu[self.cur_stage] = { + "begin": self.gpu_mem_used_at_start, + "end": self.gpu_mem_used_now, + "alloc": (self.gpu_mem_used_now - self.gpu_mem_used_at_start), + } + if self.gpu_mem_used_peak is not None: + self.gpu[self.cur_stage]["peaked"] = max(0, self.gpu_mem_used_peak - self.gpu_mem_used_now) + else: + self.gpu[self.cur_stage]["peaked"] = "Not available" + + # cpu + self.cpu_mem_used_now = self.cpu_mem_used() + self.cpu[self.cur_stage] = { + "begin": self.cpu_mem_used_at_start, + "end": self.cpu_mem_used_now, + "alloc": (self.cpu_mem_used_now - self.cpu_mem_used_at_start), + "peaked": max(0, self.cpu_mem_used_peak - self.cpu_mem_used_now), + } + + # reset - cycle finished + self.cur_stage = None + + def update_metrics(self, stage, metrics): + """updates the metrics""" + if self.skip_memory_metrics: + return + + # deal with nested calls of eval during train - simply ignore those + if self.cur_stage is not None and self.cur_stage != stage: + return + + # since we don't have a way to return init metrics, we push them into the first of train/val/predict + stages = [stage] + if not self.init_reported: + stages.insert(0, "init") + self.init_reported = True + + for stage in stages: + for t in ["alloc", "peaked"]: + if stage in self.cpu and t in self.cpu[stage]: + metrics[f"{stage}_mem_cpu_{t}_delta"] = self.cpu[stage][t] + if self.torch is not None and stage in self.gpu and t in self.gpu[stage]: + metrics[f"{stage}_mem_gpu_{t}_delta"] = self.gpu[stage][t] + # if we need additional debug info, enable the following + # for t in ["begin", "end"]: + # if stage in self.cpu and t in self.cpu[stage]: + # metrics[f"{stage}_mem_cpu_{t}"] = self.cpu[stage][t] + # if self.torch is not None and stage in self.gpu and t in self.gpu[stage]: + # metrics[f"{stage}_mem_gpu_{t}"] = self.gpu[stage][t] + + # since memory can be allocated before init, and it might be difficult to track overall + # memory usage, in particular for GPU, let's report memory usage at the point init was called + if stages[0] == "init": + metrics["before_init_mem_cpu"] = self.cpu["init"]["begin"] + if self.torch is not None: + metrics["before_init_mem_gpu"] = self.gpu["init"]["begin"] + # if we also wanted to report any additional memory allocations in between init and + # whatever the next stage was we could also report this: + # if self.cpu["init"]["end"] != self.cpu[stage]["begin"]: + # metrics[f"after_init_mem_cpu_delta"] = self.cpu[stage]["begin"] - self.cpu["init"]["end"] + # if self.torch is not None and self.gpu["init"]["end"] != self.gpu[stage]["begin"]: + # metrics[f"after_init_mem_gpu_delta"] = self.gpu[stage]["begin"] - self.gpu["init"]["end"] + + def stop_and_update_metrics(self, metrics=None): + """combine stop and metrics update in one call for simpler code""" + if self.skip_memory_metrics: + return + + stage = self.derive_stage() + self.stop(stage) + + # init doesn't have metrics to update so we just save that data for later stages to retrieve + if metrics is not None: + self.update_metrics(stage, metrics) + + +def has_length(dataset): + """ + Checks if the dataset implements __len__() and it doesn't raise an error + """ + try: + return len(dataset) is not None + except TypeError: + # TypeError: len() of unsized object + return False + except AttributeError: + # Ray DataSets raises an AttributeError: https://github.com/ray-project/ray/blob/master/python/ray/data/dataset.py#L5616 + return False + + +def denumpify_detensorize(metrics): + """ + Recursively calls `.item()` on the element of the dictionary passed + """ + if isinstance(metrics, (list, tuple)): + return type(metrics)(denumpify_detensorize(m) for m in metrics) + elif isinstance(metrics, dict): + return type(metrics)({k: denumpify_detensorize(v) for k, v in metrics.items()}) + elif isinstance(metrics, np.generic): + return metrics.item() + elif is_torch_available() and isinstance(metrics, torch.Tensor) and metrics.numel() == 1: + return metrics.item() + return metrics + + +def number_of_arguments(func): + """ + Return the number of arguments of the passed function, even if it's a partial function. + """ + if isinstance(func, functools.partial): + total_args = len(inspect.signature(func.func).parameters) + return total_args - len(func.args) - len(func.keywords) + return len(inspect.signature(func).parameters) + + +def find_executable_batch_size( + function: Optional[Callable] = None, starting_batch_size: int = 128, auto_find_batch_size: bool = False +): + """ + Args: + A basic decorator that will try to execute `function`. If it fails from exceptions related to out-of-memory or + CUDNN, the batch size is multiplied by 0.9 and passed to `function`. `function` must take in a `batch_size` parameter as + its first argument. + function (`Callable`, *optional*) + A function to wrap + starting_batch_size (`int`, *optional*) + The batch size to try and fit into memory + auto_find_batch_size (`bool`, *optional*) + If False, will just execute `function` + """ + if function is None: + return functools.partial( + find_executable_batch_size, + starting_batch_size=starting_batch_size, + auto_find_batch_size=auto_find_batch_size, + ) + + if auto_find_batch_size: + requires_backends(find_executable_batch_size, "accelerate") + from accelerate.utils import find_executable_batch_size as accelerate_find_executable_batch_size + + return accelerate_find_executable_batch_size(function=function, starting_batch_size=starting_batch_size) + + return functools.partial(function, batch_size=starting_batch_size) + + +class FSDPOption(ExplicitEnum): + FULL_SHARD = "full_shard" + SHARD_GRAD_OP = "shard_grad_op" + NO_SHARD = "no_shard" + HYBRID_SHARD = "hybrid_shard" + HYBRID_SHARD_ZERO2 = "hybrid_shard_zero2" + OFFLOAD = "offload" + AUTO_WRAP = "auto_wrap" + + +class RemoveColumnsCollator: + """Wrap the data collator to remove unused columns before they are passed to the collator.""" + + def __init__( + self, + data_collator, + signature_columns, + logger=None, + model_name: Optional[str] = None, + description: Optional[str] = None, + ): + self.data_collator = data_collator + self.signature_columns = signature_columns + self.logger = logger + self.description = description + self.model_name = model_name + self.message_logged = False + + def _remove_columns(self, feature: dict) -> dict: + if not isinstance(feature, dict): + return feature + if not self.message_logged and self.logger and self.model_name: + ignored_columns = list(set(feature.keys()) - set(self.signature_columns)) + if len(ignored_columns) > 0: + dset_description = "" if self.description is None else f"in the {self.description} set" + self.logger.info( + f"The following columns {dset_description} don't have a corresponding argument in " + f"`{self.model_name}.forward` and have been ignored: {', '.join(ignored_columns)}." + f" If {', '.join(ignored_columns)} are not expected by `{self.model_name}.forward`, " + " you can safely ignore this message." + ) + self.message_logged = True + return {k: v for k, v in feature.items() if k in self.signature_columns} + + def __call__(self, features: list[dict]): + features = [self._remove_columns(feature) for feature in features] + return self.data_collator(features) + + +def check_target_module_exists(optim_target_modules, key: str, return_is_regex: bool = False): + """A helper method to check if the passed module's key name matches any of the target modules in the optim_target_modules. + + Args: + optim_target_modules (`Union[str, list[str]]`): + A list of strings to try to match. Can be also a full string. + key (`str`): + A key to search any matches in optim_target_modules + return_is_regex (`bool`): + If set to `True`, the method will return whether the passed `optim_target_modules` + is a regex or not. + + Returns: + `bool` : True of match object if key matches any target modules from config, False or + None if no match found + `bool` : If the matched target module is a regex to silence out the warnings in Trainer + for extra modules being found (only if `target_module_found=True` for an array of regex). + """ + target_module_found = False + is_regex = False + + if isinstance(optim_target_modules, str): + target_module_found = bool(re.fullmatch(optim_target_modules, key)) + is_regex = optim_target_modules != key + elif key in optim_target_modules: # from here, target_module_found must be a list of str + # this module is specified directly in target_modules + target_module_found = True + elif any(target_key in key for target_key in optim_target_modules): + target_module_found = True + elif any(bool(re.fullmatch(optim_target_module, key)) for optim_target_module in optim_target_modules): + target_module_found = True + is_regex = True + + if return_is_regex: + return target_module_found, is_regex + + return target_module_found diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/training_args.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/training_args.py new file mode 100644 index 0000000000000000000000000000000000000000..79fdfe6c0c2e824eae03d09b5156312e0fe587c3 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/training_args.py @@ -0,0 +1,3167 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import json +import math +import os +import warnings +from dataclasses import asdict, dataclass, field, fields +from datetime import timedelta +from enum import Enum +from functools import cached_property +from pathlib import Path +from typing import Any, Optional, Union + +from huggingface_hub import get_full_repo_name + +from .debug_utils import DebugOption +from .trainer_utils import ( + EvaluationStrategy, + FSDPOption, + HubStrategy, + IntervalStrategy, + SaveStrategy, + SchedulerType, +) +from .utils import ( + ACCELERATE_MIN_VERSION, + ExplicitEnum, + is_accelerate_available, + is_apex_available, + is_ipex_available, + is_sagemaker_dp_enabled, + is_sagemaker_mp_enabled, + is_torch_available, + is_torch_bf16_gpu_available, + is_torch_cuda_available, + is_torch_hpu_available, + is_torch_mlu_available, + is_torch_mps_available, + is_torch_musa_available, + is_torch_neuroncore_available, + is_torch_npu_available, + is_torch_tf32_available, + is_torch_xla_available, + is_torch_xpu_available, + logging, + requires_backends, +) +from .utils.generic import strtobool +from .utils.import_utils import is_optimum_neuron_available + + +logger = logging.get_logger(__name__) +log_levels = logging.get_log_levels_dict().copy() +trainer_log_levels = dict(**log_levels, passive=-1) + +if is_torch_available(): + import torch + import torch.distributed as dist + +if is_accelerate_available(): + from accelerate.state import AcceleratorState, PartialState + from accelerate.utils import DistributedType + + from .trainer_pt_utils import AcceleratorConfig + +if is_accelerate_available("1.10.1"): + from accelerate.parallelism_config import ParallelismConfig +else: + ParallelismConfig = Any + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + +if is_torch_neuroncore_available(check_device=False): + # torchrun support + # https://github.com/pytorch/xla/pull/3609 + if os.environ.get("TORCHELASTIC_RUN_ID"): + if is_optimum_neuron_available(): + logger.info( + "Make sure that you are performing the training with the NeuronTrainer from optimum[neuron], this " + "will fail otherwise." + ) + else: + logger.warning( + "Please use the NeuronTrainer from optimum[neuron] instead of the Transformers library to perform " + "training on AWS Trainium instances. More information here: " + "https://github.com/huggingface/optimum-neuron" + ) + import torch_xla.distributed.xla_backend as xbn + + if not isinstance(dist.group.WORLD, xbn.ProcessGroupXla): + dist.init_process_group(backend="xla") + if not isinstance(dist.group.WORLD, xbn.ProcessGroupXla): + raise AssertionError("Failed to initialize torch.distributed process group using XLA backend.") + + +if is_sagemaker_mp_enabled(): + import smdistributed.modelparallel.torch as smp + + smp.init() + + +def default_logdir() -> str: + """ + Same default as PyTorch + """ + import socket + from datetime import datetime + + current_time = datetime.now().strftime("%b%d_%H-%M-%S") + return os.path.join("runs", current_time + "_" + socket.gethostname()) + + +def get_int_from_env(env_keys, default): + """Returns the first positive env value found in the `env_keys` list or the default.""" + for e in env_keys: + val = int(os.environ.get(e, "-1")) + if val >= 0: + return val + return default + + +def get_xla_device_type(device: "torch.device") -> Optional[str]: + """ + Returns the xla device type (CPU|GPU|TPU) or None if the device is a non-xla device. + """ + if is_torch_xla_available(): + if device.type == "cpu": + return "CPU" + return xm.xla_real_devices([device])[0].split(":")[0] + return None + + +class OptimizerNames(ExplicitEnum): + """ + Stores the acceptable string identifiers for optimizers. + """ + + ADAMW_TORCH = "adamw_torch" + ADAMW_TORCH_FUSED = "adamw_torch_fused" + ADAMW_TORCH_XLA = "adamw_torch_xla" + ADAMW_TORCH_NPU_FUSED = "adamw_torch_npu_fused" + ADAMW_APEX_FUSED = "adamw_apex_fused" + ADAFACTOR = "adafactor" + ADAMW_ANYPRECISION = "adamw_anyprecision" + ADAMW_TORCH_4BIT = "adamw_torch_4bit" + ADAMW_TORCH_8BIT = "adamw_torch_8bit" + ADEMAMIX = "ademamix" + SGD = "sgd" + ADAGRAD = "adagrad" + ADAMW_BNB = "adamw_bnb_8bit" + ADAMW_8BIT = "adamw_8bit" # just an alias for adamw_bnb_8bit + ADEMAMIX_8BIT = "ademamix_8bit" + LION_8BIT = "lion_8bit" + LION = "lion_32bit" + PAGED_ADAMW = "paged_adamw_32bit" + PAGED_ADAMW_8BIT = "paged_adamw_8bit" + PAGED_ADEMAMIX = "paged_ademamix_32bit" + PAGED_ADEMAMIX_8BIT = "paged_ademamix_8bit" + PAGED_LION = "paged_lion_32bit" + PAGED_LION_8BIT = "paged_lion_8bit" + RMSPROP = "rmsprop" + RMSPROP_BNB = "rmsprop_bnb" + RMSPROP_8BIT = "rmsprop_bnb_8bit" + RMSPROP_32BIT = "rmsprop_bnb_32bit" + GALORE_ADAMW = "galore_adamw" + GALORE_ADAMW_8BIT = "galore_adamw_8bit" + GALORE_ADAFACTOR = "galore_adafactor" + GALORE_ADAMW_LAYERWISE = "galore_adamw_layerwise" + GALORE_ADAMW_8BIT_LAYERWISE = "galore_adamw_8bit_layerwise" + GALORE_ADAFACTOR_LAYERWISE = "galore_adafactor_layerwise" + LOMO = "lomo" + ADALOMO = "adalomo" + GROKADAMW = "grokadamw" + SCHEDULE_FREE_RADAM = "schedule_free_radam" + SCHEDULE_FREE_ADAMW = "schedule_free_adamw" + SCHEDULE_FREE_SGD = "schedule_free_sgd" + APOLLO_ADAMW = "apollo_adamw" + APOLLO_ADAMW_LAYERWISE = "apollo_adamw_layerwise" + STABLE_ADAMW = "stable_adamw" + + +def _convert_str_dict(passed_value: dict): + "Safely checks that a passed value is a dictionary and converts any string values to their appropriate types." + for key, value in passed_value.items(): + if isinstance(value, dict): + passed_value[key] = _convert_str_dict(value) + elif isinstance(value, str): + # First check for bool and convert + if value.lower() in ("true", "false"): + passed_value[key] = value.lower() == "true" + # Check for digit + elif value.isdigit(): + passed_value[key] = int(value) + elif value.replace(".", "", 1).isdigit(): + passed_value[key] = float(value) + + return passed_value + + +# TODO: `TrainingArguments` users rely on it being fully mutable. In the future see if we can narrow this to a few keys: https://github.com/huggingface/transformers/pull/25903 +@dataclass +class TrainingArguments: + """ + TrainingArguments is the subset of the arguments we use in our example scripts **which relate to the training loop + itself**. + + Using [`HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + output_dir (`str`, *optional*, defaults to `"trainer_output"`): + The output directory where the model predictions and checkpoints will be written. + overwrite_output_dir (`bool`, *optional*, defaults to `False`): + If `True`, overwrite the content of the output directory. Use this to continue training if `output_dir` + points to a checkpoint directory. + do_train (`bool`, *optional*, defaults to `False`): + Whether to run training or not. This argument is not directly used by [`Trainer`], it's intended to be used + by your training/evaluation scripts instead. See the [example + scripts](https://github.com/huggingface/transformers/tree/main/examples) for more details. + do_eval (`bool`, *optional*): + Whether to run evaluation on the validation set or not. Will be set to `True` if `eval_strategy` is + different from `"no"`. This argument is not directly used by [`Trainer`], it's intended to be used by your + training/evaluation scripts instead. See the [example + scripts](https://github.com/huggingface/transformers/tree/main/examples) for more details. + do_predict (`bool`, *optional*, defaults to `False`): + Whether to run predictions on the test set or not. This argument is not directly used by [`Trainer`], it's + intended to be used by your training/evaluation scripts instead. See the [example + scripts](https://github.com/huggingface/transformers/tree/main/examples) for more details. + eval_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"no"`): + The evaluation strategy to adopt during training. Possible values are: + + - `"no"`: No evaluation is done during training. + - `"steps"`: Evaluation is done (and logged) every `eval_steps`. + - `"epoch"`: Evaluation is done at the end of each epoch. + + prediction_loss_only (`bool`, *optional*, defaults to `False`): + When performing evaluation and generating predictions, only returns the loss. + per_device_train_batch_size (`int`, *optional*, defaults to 8): + The batch size *per device*. The **global batch size** is computed as: + `per_device_train_batch_size * number_of_devices` in multi-GPU or distributed setups. + per_device_eval_batch_size (`int`, *optional*, defaults to 8): + The batch size per device accelerator core/CPU for evaluation. + gradient_accumulation_steps (`int`, *optional*, defaults to 1): + Number of updates steps to accumulate the gradients for, before performing a backward/update pass. + + + + When using gradient accumulation, one step is counted as one step with backward pass. Therefore, logging, + evaluation, save will be conducted every `gradient_accumulation_steps * xxx_step` training examples. + + + + eval_accumulation_steps (`int`, *optional*): + Number of predictions steps to accumulate the output tensors for, before moving the results to the CPU. If + left unset, the whole predictions are accumulated on the device accelerator before being moved to the CPU (faster but + requires more memory). + eval_delay (`float`, *optional*): + Number of epochs or steps to wait for before the first evaluation can be performed, depending on the + eval_strategy. + torch_empty_cache_steps (`int`, *optional*): + Number of steps to wait before calling `torch..empty_cache()`. If left unset or set to None, cache will not be emptied. + + + + This can help avoid CUDA out-of-memory errors by lowering peak VRAM usage at a cost of about [10% slower performance](https://github.com/huggingface/transformers/issues/31372). + + + + learning_rate (`float`, *optional*, defaults to 5e-5): + The initial learning rate for [`AdamW`] optimizer. + weight_decay (`float`, *optional*, defaults to 0): + The weight decay to apply (if not zero) to all layers except all bias and LayerNorm weights in [`AdamW`] + optimizer. + adam_beta1 (`float`, *optional*, defaults to 0.9): + The beta1 hyperparameter for the [`AdamW`] optimizer. + adam_beta2 (`float`, *optional*, defaults to 0.999): + The beta2 hyperparameter for the [`AdamW`] optimizer. + adam_epsilon (`float`, *optional*, defaults to 1e-8): + The epsilon hyperparameter for the [`AdamW`] optimizer. + max_grad_norm (`float`, *optional*, defaults to 1.0): + Maximum gradient norm (for gradient clipping). + num_train_epochs(`float`, *optional*, defaults to 3.0): + Total number of training epochs to perform (if not an integer, will perform the decimal part percents of + the last epoch before stopping training). + max_steps (`int`, *optional*, defaults to -1): + If set to a positive number, the total number of training steps to perform. Overrides `num_train_epochs`. + For a finite dataset, training is reiterated through the dataset (if all data is exhausted) until + `max_steps` is reached. + lr_scheduler_type (`str` or [`SchedulerType`], *optional*, defaults to `"linear"`): + The scheduler type to use. See the documentation of [`SchedulerType`] for all possible values. + lr_scheduler_kwargs (`dict` or `str`, *optional*, defaults to `None`): + The extra arguments for the lr_scheduler. See the documentation of each scheduler for possible values. + warmup_ratio (`float`, *optional*, defaults to 0.0): + Ratio of total training steps used for a linear warmup from 0 to `learning_rate`. + warmup_steps (`int`, *optional*, defaults to 0): + Number of steps used for a linear warmup from 0 to `learning_rate`. Overrides any effect of `warmup_ratio`. + log_level (`str`, *optional*, defaults to `passive`): + Logger log level to use on the main process. Possible choices are the log levels as strings: 'debug', + 'info', 'warning', 'error' and 'critical', plus a 'passive' level which doesn't set anything and keeps the + current log level for the Transformers library (which will be `"warning"` by default). + log_level_replica (`str`, *optional*, defaults to `"warning"`): + Logger log level to use on replicas. Same choices as `log_level`" + log_on_each_node (`bool`, *optional*, defaults to `True`): + In multinode distributed training, whether to log using `log_level` once per node, or only on the main + node. + logging_dir (`str`, *optional*): + [TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to + *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***. + logging_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"steps"`): + The logging strategy to adopt during training. Possible values are: + + - `"no"`: No logging is done during training. + - `"epoch"`: Logging is done at the end of each epoch. + - `"steps"`: Logging is done every `logging_steps`. + + logging_first_step (`bool`, *optional*, defaults to `False`): + Whether to log the first `global_step` or not. + logging_steps (`int` or `float`, *optional*, defaults to 500): + Number of update steps between two logs if `logging_strategy="steps"`. Should be an integer or a float in + range `[0,1)`. If smaller than 1, will be interpreted as ratio of total training steps. + logging_nan_inf_filter (`bool`, *optional*, defaults to `True`): + Whether to filter `nan` and `inf` losses for logging. If set to `True` the loss of every step that is `nan` + or `inf` is filtered and the average loss of the current logging window is taken instead. + + + + `logging_nan_inf_filter` only influences the logging of loss values, it does not change the behavior the + gradient is computed or applied to the model. + + + + save_strategy (`str` or [`~trainer_utils.SaveStrategy`], *optional*, defaults to `"steps"`): + The checkpoint save strategy to adopt during training. Possible values are: + + - `"no"`: No save is done during training. + - `"epoch"`: Save is done at the end of each epoch. + - `"steps"`: Save is done every `save_steps`. + - `"best"`: Save is done whenever a new `best_metric` is achieved. + + If `"epoch"` or `"steps"` is chosen, saving will also be performed at the + very end of training, always. + save_steps (`int` or `float`, *optional*, defaults to 500): + Number of updates steps before two checkpoint saves if `save_strategy="steps"`. Should be an integer or a + float in range `[0,1)`. If smaller than 1, will be interpreted as ratio of total training steps. + save_total_limit (`int`, *optional*): + If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in + `output_dir`. When `load_best_model_at_end` is enabled, the "best" checkpoint according to + `metric_for_best_model` will always be retained in addition to the most recent ones. For example, for + `save_total_limit=5` and `load_best_model_at_end`, the four last checkpoints will always be retained + alongside the best model. When `save_total_limit=1` and `load_best_model_at_end`, it is possible that two + checkpoints are saved: the last one and the best one (if they are different). + save_safetensors (`bool`, *optional*, defaults to `True`): + Use [safetensors](https://huggingface.co/docs/safetensors) saving and loading for state dicts instead of + default `torch.load` and `torch.save`. + save_on_each_node (`bool`, *optional*, defaults to `False`): + When doing multi-node distributed training, whether to save models and checkpoints on each node, or only on + the main one. + + This should not be activated when the different nodes use the same storage as the files will be saved with + the same names for each node. + save_only_model (`bool`, *optional*, defaults to `False`): + When checkpointing, whether to only save the model, or also the optimizer, scheduler & rng state. + Note that when this is true, you won't be able to resume training from checkpoint. + This enables you to save storage by not storing the optimizer, scheduler & rng state. + You can only load the model using `from_pretrained` with this option set to `True`. + restore_callback_states_from_checkpoint (`bool`, *optional*, defaults to `False`): + Whether to restore the callback states from the checkpoint. If `True`, will override + callbacks passed to the `Trainer` if they exist in the checkpoint." + use_cpu (`bool`, *optional*, defaults to `False`): + Whether or not to use cpu. If set to False, we will use cuda or mps device if available. + seed (`int`, *optional*, defaults to 42): + Random seed that will be set at the beginning of training. To ensure reproducibility across runs, use the + [`~Trainer.model_init`] function to instantiate the model if it has some randomly initialized parameters. + data_seed (`int`, *optional*): + Random seed to be used with data samplers. If not set, random generators for data sampling will use the + same seed as `seed`. This can be used to ensure reproducibility of data sampling, independent of the model + seed. + jit_mode_eval (`bool`, *optional*, defaults to `False`): + Whether or not to use PyTorch jit trace for inference. + bf16 (`bool`, *optional*, defaults to `False`): + Whether to use bf16 16-bit (mixed) precision training instead of 32-bit training. Requires Ampere or higher + NVIDIA architecture or Intel XPU or using CPU (use_cpu) or Ascend NPU. + fp16 (`bool`, *optional*, defaults to `False`): + Whether to use fp16 16-bit (mixed) precision training instead of 32-bit training. + fp16_opt_level (`str`, *optional*, defaults to 'O1'): + For `fp16` training, Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. See details on + the [Apex documentation](https://nvidia.github.io/apex/amp). + fp16_backend (`str`, *optional*, defaults to `"auto"`): + This argument is deprecated. Use `half_precision_backend` instead. + half_precision_backend (`str`, *optional*, defaults to `"auto"`): + The backend to use for mixed precision training. Must be one of `"auto", "apex", "cpu_amp"`. `"auto"` will + use CPU/CUDA AMP or APEX depending on the PyTorch version detected, while the other choices will force the + requested backend. + bf16_full_eval (`bool`, *optional*, defaults to `False`): + Whether to use full bfloat16 evaluation instead of 32-bit. This will be faster and save memory but can harm + metric values. + fp16_full_eval (`bool`, *optional*, defaults to `False`): + Whether to use full float16 evaluation instead of 32-bit. This will be faster and save memory but can harm + metric values. + tf32 (`bool`, *optional*): + Whether to enable the TF32 mode, available in Ampere and newer GPU architectures. The default value depends + on PyTorch's version default of `torch.backends.cuda.matmul.allow_tf32`. For more details please refer to + the [TF32](https://huggingface.co/docs/transformers/perf_train_gpu_one#tf32) documentation. This is an + experimental API and it may change. + local_rank (`int`, *optional*, defaults to -1): + Rank of the process during distributed training. + ddp_backend (`str`, *optional*): + The backend to use for distributed training. Must be one of `"nccl"`, `"mpi"`, `"ccl"`, `"gloo"`, `"hccl"`. + tpu_num_cores (`int`, *optional*): + When training on TPU, the number of TPU cores (automatically passed by launcher script). + dataloader_drop_last (`bool`, *optional*, defaults to `False`): + Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch size) + or not. + eval_steps (`int` or `float`, *optional*): + Number of update steps between two evaluations if `eval_strategy="steps"`. Will default to the same + value as `logging_steps` if not set. Should be an integer or a float in range `[0,1)`. If smaller than 1, + will be interpreted as ratio of total training steps. + dataloader_num_workers (`int`, *optional*, defaults to 0): + Number of subprocesses to use for data loading (PyTorch only). 0 means that the data will be loaded in the + main process. + past_index (`int`, *optional*, defaults to -1): + Some models like [TransformerXL](../model_doc/transformerxl) or [XLNet](../model_doc/xlnet) can make use of + the past hidden states for their predictions. If this argument is set to a positive int, the `Trainer` will + use the corresponding output (usually index 2) as the past state and feed it to the model at the next + training step under the keyword argument `mems`. + run_name (`str`, *optional*, defaults to `output_dir`): + A descriptor for the run. Typically used for [trackio](https://github.com/gradio-app/trackio), + [wandb](https://www.wandb.com/), [mlflow](https://www.mlflow.org/), [comet](https://www.comet.com/site) and + [swanlab](https://swanlab.cn) logging. If not specified, will be the same as `output_dir`. + disable_tqdm (`bool`, *optional*): + Whether or not to disable the tqdm progress bars and table of metrics produced by + [`~notebook.NotebookTrainingTracker`] in Jupyter Notebooks. Will default to `True` if the logging level is + set to warn or lower (default), `False` otherwise. + remove_unused_columns (`bool`, *optional*, defaults to `True`): + Whether or not to automatically remove the columns unused by the model forward method. + label_names (`list[str]`, *optional*): + The list of keys in your dictionary of inputs that correspond to the labels. + + Will eventually default to the list of argument names accepted by the model that contain the word "label", + except if the model used is one of the `XxxForQuestionAnswering` in which case it will also include the + `["start_positions", "end_positions"]` keys. + + You should only specify `label_names` if you're using custom label names or if your model's `forward` consumes multiple label tensors (e.g., extractive QA). + load_best_model_at_end (`bool`, *optional*, defaults to `False`): + Whether or not to load the best model found during training at the end of training. When this option is + enabled, the best checkpoint will always be saved. See + [`save_total_limit`](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments.save_total_limit) + for more. + + + + When set to `True`, the parameters `save_strategy` needs to be the same as `eval_strategy`, and in + the case it is "steps", `save_steps` must be a round multiple of `eval_steps`. + + + + metric_for_best_model (`str`, *optional*): + Use in conjunction with `load_best_model_at_end` to specify the metric to use to compare two different + models. Must be the name of a metric returned by the evaluation with or without the prefix `"eval_"`. + + If not specified, this will default to `"loss"` when either `load_best_model_at_end == True` + or `lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU` (to use the evaluation loss). + + If you set this value, `greater_is_better` will default to `True` unless the name ends with "loss". + Don't forget to set it to `False` if your metric is better when lower. + greater_is_better (`bool`, *optional*): + Use in conjunction with `load_best_model_at_end` and `metric_for_best_model` to specify if better models + should have a greater metric or not. Will default to: + + - `True` if `metric_for_best_model` is set to a value that doesn't end in `"loss"`. + - `False` if `metric_for_best_model` is not set, or set to a value that ends in `"loss"`. + ignore_data_skip (`bool`, *optional*, defaults to `False`): + When resuming training, whether or not to skip the epochs and batches to get the data loading at the same + stage as in the previous training. If set to `True`, the training will begin faster (as that skipping step + can take a long time) but will not yield the same results as the interrupted training would have. + fsdp (`bool`, `str` or list of [`~trainer_utils.FSDPOption`], *optional*, defaults to `None`): + Use PyTorch Distributed Parallel Training (in distributed training only). + + A list of options along the following: + + - `"full_shard"`: Shard parameters, gradients and optimizer states. + - `"shard_grad_op"`: Shard optimizer states and gradients. + - `"hybrid_shard"`: Apply `FULL_SHARD` within a node, and replicate parameters across nodes. + - `"hybrid_shard_zero2"`: Apply `SHARD_GRAD_OP` within a node, and replicate parameters across nodes. + - `"offload"`: Offload parameters and gradients to CPUs (only compatible with `"full_shard"` and + `"shard_grad_op"`). + - `"auto_wrap"`: Automatically recursively wrap layers with FSDP using `default_auto_wrap_policy`. + fsdp_config (`str` or `dict`, *optional*): + Config to be used with fsdp (Pytorch Distributed Parallel Training). The value is either a location of + fsdp json config file (e.g., `fsdp_config.json`) or an already loaded json file as `dict`. + + A List of config and its options: + - min_num_params (`int`, *optional*, defaults to `0`): + FSDP's minimum number of parameters for Default Auto Wrapping. (useful only when `fsdp` field is + passed). + - transformer_layer_cls_to_wrap (`list[str]`, *optional*): + List of transformer layer class names (case-sensitive) to wrap, e.g, `BertLayer`, `GPTJBlock`, + `T5Block` .... (useful only when `fsdp` flag is passed). + - backward_prefetch (`str`, *optional*) + FSDP's backward prefetch mode. Controls when to prefetch next set of parameters (useful only when + `fsdp` field is passed). + + A list of options along the following: + + - `"backward_pre"` : Prefetches the next set of parameters before the current set of parameter's + gradient computation. + - `"backward_post"` : This prefetches the next set of parameters after the current set of + parameter's gradient computation. + - forward_prefetch (`bool`, *optional*, defaults to `False`) + FSDP's forward prefetch mode (useful only when `fsdp` field is passed). + If `"True"`, then FSDP explicitly prefetches the next upcoming all-gather while executing in the + forward pass. + - limit_all_gathers (`bool`, *optional*, defaults to `False`) + FSDP's limit_all_gathers (useful only when `fsdp` field is passed). + If `"True"`, FSDP explicitly synchronizes the CPU thread to prevent too many in-flight + all-gathers. + - use_orig_params (`bool`, *optional*, defaults to `True`) + If `"True"`, allows non-uniform `requires_grad` during init, which means support for interspersed + frozen and trainable parameters. Useful in cases such as parameter-efficient fine-tuning. Please + refer this + [blog](https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019 + - sync_module_states (`bool`, *optional*, defaults to `True`) + If `"True"`, each individually wrapped FSDP unit will broadcast module parameters from rank 0 to + ensure they are the same across all ranks after initialization + - cpu_ram_efficient_loading (`bool`, *optional*, defaults to `False`) + If `"True"`, only the first process loads the pretrained model checkpoint while all other processes + have empty weights. When this setting as `"True"`, `sync_module_states` also must to be `"True"`, + otherwise all the processes except the main process would have random weights leading to unexpected + behaviour during training. + - activation_checkpointing (`bool`, *optional*, defaults to `False`): + If `"True"`, activation checkpointing is a technique to reduce memory usage by clearing activations of + certain layers and recomputing them during a backward pass. Effectively, this trades extra + computation time for reduced memory usage. + - xla (`bool`, *optional*, defaults to `False`): + Whether to use PyTorch/XLA Fully Sharded Data Parallel Training. This is an experimental feature + and its API may evolve in the future. + - xla_fsdp_settings (`dict`, *optional*) + The value is a dictionary which stores the XLA FSDP wrapping parameters. + + For a complete list of options, please see [here]( + https://github.com/pytorch/xla/blob/master/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py). + - xla_fsdp_grad_ckpt (`bool`, *optional*, defaults to `False`): + Will use gradient checkpointing over each nested XLA FSDP wrapped layer. This setting can only be + used when the xla flag is set to true, and an auto wrapping policy is specified through + fsdp_min_num_params or fsdp_transformer_layer_cls_to_wrap. + deepspeed (`str` or `dict`, *optional*): + Use [Deepspeed](https://github.com/deepspeedai/DeepSpeed). This is an experimental feature and its API may + evolve in the future. The value is either the location of DeepSpeed json config file (e.g., + `ds_config.json`) or an already loaded json file as a `dict`" + + + If enabling any Zero-init, make sure that your model is not initialized until + *after* initializing the `TrainingArguments`, else it will not be applied. + + + accelerator_config (`str`, `dict`, or `AcceleratorConfig`, *optional*): + Config to be used with the internal `Accelerator` implementation. The value is either a location of + accelerator json config file (e.g., `accelerator_config.json`), an already loaded json file as `dict`, + or an instance of [`~trainer_pt_utils.AcceleratorConfig`]. + + A list of config and its options: + - split_batches (`bool`, *optional*, defaults to `False`): + Whether or not the accelerator should split the batches yielded by the dataloaders across the devices. If + `True` the actual batch size used will be the same on any kind of distributed processes, but it must be a + round multiple of the `num_processes` you are using. If `False`, actual batch size used will be the one set + in your script multiplied by the number of processes. + - dispatch_batches (`bool`, *optional*): + If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process + and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose + underlying dataset is an `IterableDataset`, `False` otherwise. + - even_batches (`bool`, *optional*, defaults to `True`): + If set to `True`, in cases where the total batch size across all processes does not exactly divide the + dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among + all workers. + - use_seedable_sampler (`bool`, *optional*, defaults to `True`): + Whether or not use a fully seedable random sampler ([`accelerate.data_loader.SeedableRandomSampler`]). Ensures + training results are fully reproducible using a different sampling technique. While seed-to-seed results + may differ, on average the differences are negligible when using multiple different seeds to compare. Should + also be ran with [`~utils.set_seed`] for the best results. + - use_configured_state (`bool`, *optional*, defaults to `False`): + Whether or not to use a pre-configured `AcceleratorState` or `PartialState` defined before calling `TrainingArguments`. + If `True`, an `Accelerator` or `PartialState` must be initialized. Note that by doing so, this could lead to issues + with hyperparameter tuning. + parallelism_config (`ParallelismConfig`, *optional*): + Parallelism configuration for the training run. Requires Accelerate `1.10.1` + label_smoothing_factor (`float`, *optional*, defaults to 0.0): + The label smoothing factor to use. Zero means no label smoothing, otherwise the underlying onehot-encoded + labels are changed from 0s and 1s to `label_smoothing_factor/num_labels` and `1 - label_smoothing_factor + + label_smoothing_factor/num_labels` respectively. + debug (`str` or list of [`~debug_utils.DebugOption`], *optional*, defaults to `""`): + Enable one or more debug features. This is an experimental feature. + + Possible options are: + + - `"underflow_overflow"`: detects overflow in model's input/outputs and reports the last frames that led to + the event + - `"tpu_metrics_debug"`: print debug metrics on TPU + + The options should be separated by whitespaces. + optim (`str` or [`training_args.OptimizerNames`], *optional*, defaults to `"adamw_torch"` (for torch>=2.8 `"adamw_torch_fused"`)): + The optimizer to use, such as "adamw_torch", "adamw_torch_fused", "adamw_apex_fused", "adamw_anyprecision", + "adafactor". See `OptimizerNames` in [training_args.py](https://github.com/huggingface/transformers/blob/main/src/transformers/training_args.py) + for a full list of optimizers. + optim_args (`str`, *optional*): + Optional arguments that are supplied to optimizers such as AnyPrecisionAdamW, AdEMAMix, and GaLore. + group_by_length (`bool`, *optional*, defaults to `False`): + Whether or not to group together samples of roughly the same length in the training dataset (to minimize + padding applied and be more efficient). Only useful if applying dynamic padding. + length_column_name (`str`, *optional*, defaults to `"length"`): + Column name for precomputed lengths. If the column exists, grouping by length will use these values rather + than computing them on train startup. Ignored unless `group_by_length` is `True` and the dataset is an + instance of `Dataset`. + report_to (`str` or `list[str]`, *optional*, defaults to `"all"`): + The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`, + `"clearml"`, `"codecarbon"`, `"comet_ml"`, `"dagshub"`, `"dvclive"`, `"flyte"`, `"mlflow"`, `"neptune"`, + `"swanlab"`, `"tensorboard"`, `"trackio"` and `"wandb"`. Use `"all"` to report to all integrations + installed, `"none"` for no integrations. + project (`str`, *optional*, defaults to `"huggingface"`): + The name of the project to use for logging. Currently, only used by Trackio. + trackio_space_id (`str` or `None`, *optional*, defaults to `"trackio"`): + The Hugging Face Space ID to deploy to when using Trackio. Should be a complete Space name like + `'username/reponame'` or `'orgname/reponame' `, or just `'reponame'` in which case the Space will be + created in the currently-logged-in Hugging Face user's namespace. If `None`, will log to a local directory. + Note that this Space will be public unless you set `hub_private_repo=True` or your organization's default + is to create private Spaces." + ddp_find_unused_parameters (`bool`, *optional*): + When using distributed training, the value of the flag `find_unused_parameters` passed to + `DistributedDataParallel`. Will default to `False` if gradient checkpointing is used, `True` otherwise. + ddp_bucket_cap_mb (`int`, *optional*): + When using distributed training, the value of the flag `bucket_cap_mb` passed to `DistributedDataParallel`. + ddp_broadcast_buffers (`bool`, *optional*): + When using distributed training, the value of the flag `broadcast_buffers` passed to + `DistributedDataParallel`. Will default to `False` if gradient checkpointing is used, `True` otherwise. + dataloader_pin_memory (`bool`, *optional*, defaults to `True`): + Whether you want to pin memory in data loaders or not. Will default to `True`. + dataloader_persistent_workers (`bool`, *optional*, defaults to `False`): + If True, the data loader will not shut down the worker processes after a dataset has been consumed once. + This allows to maintain the workers Dataset instances alive. Can potentially speed up training, but will + increase RAM usage. Will default to `False`. + dataloader_prefetch_factor (`int`, *optional*): + Number of batches loaded in advance by each worker. + 2 means there will be a total of 2 * num_workers batches prefetched across all workers. + skip_memory_metrics (`bool`, *optional*, defaults to `True`): + Whether to skip adding of memory profiler reports to metrics. This is skipped by default because it slows + down the training and evaluation speed. + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push the model to the Hub every time the model is saved. If this is activated, + `output_dir` will begin a git directory synced with the repo (determined by `hub_model_id`) and the content + will be pushed each time a save is triggered (depending on your `save_strategy`). Calling + [`~Trainer.save_model`] will also trigger a push. + + + + If `output_dir` exists, it needs to be a local clone of the repository to which the [`Trainer`] will be + pushed. + + + + resume_from_checkpoint (`str`, *optional*): + The path to a folder with a valid checkpoint for your model. This argument is not directly used by + [`Trainer`], it's intended to be used by your training/evaluation scripts instead. See the [example + scripts](https://github.com/huggingface/transformers/tree/main/examples) for more details. + hub_model_id (`str`, *optional*): + The name of the repository to keep in sync with the local *output_dir*. It can be a simple model ID in + which case the model will be pushed in your namespace. Otherwise it should be the whole repository name, + for instance `"user_name/model"`, which allows you to push to an organization you are a member of with + `"organization_name/model"`. Will default to `user_name/output_dir_name` with *output_dir_name* being the + name of `output_dir`. + + Will default to the name of `output_dir`. + hub_strategy (`str` or [`~trainer_utils.HubStrategy`], *optional*, defaults to `"every_save"`): + Defines the scope of what is pushed to the Hub and when. Possible values are: + + - `"end"`: push the model, its configuration, the processing class e.g. tokenizer (if passed along to the [`Trainer`]) and a + draft of a model card when the [`~Trainer.save_model`] method is called. + - `"every_save"`: push the model, its configuration, the processing class e.g. tokenizer (if passed along to the [`Trainer`]) and + a draft of a model card each time there is a model save. The pushes are asynchronous to not block + training, and in case the save are very frequent, a new push is only attempted if the previous one is + finished. A last push is made with the final model at the end of training. + - `"checkpoint"`: like `"every_save"` but the latest checkpoint is also pushed in a subfolder named + last-checkpoint, allowing you to resume training easily with + `trainer.train(resume_from_checkpoint="last-checkpoint")`. + - `"all_checkpoints"`: like `"checkpoint"` but all checkpoints are pushed like they appear in the output + folder (so you will get one checkpoint folder per folder in your final repository) + + hub_token (`str`, *optional*): + The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with + `hf auth login`. + hub_private_repo (`bool`, *optional*): + Whether to make the repo private. If `None` (default), the repo will be public unless the organization's + default is private. This value is ignored if the repo already exists. If reporting to Trackio with + deployment to Hugging Face Spaces enabled, the same logic determines whether the Space is private. + hub_always_push (`bool`, *optional*, defaults to `False`): + Unless this is `True`, the `Trainer` will skip pushing a checkpoint when the previous push is not finished. + hub_revision (`str`, *optional*): + The revision to use when pushing to the Hub. Can be a branch name, a tag, or a commit hash. + gradient_checkpointing (`bool`, *optional*, defaults to `False`): + If True, use gradient checkpointing to save memory at the expense of slower backward pass. + gradient_checkpointing_kwargs (`dict`, *optional*, defaults to `None`): + Key word arguments to be passed to the `gradient_checkpointing_enable` method. + include_inputs_for_metrics (`bool`, *optional*, defaults to `False`): + This argument is deprecated. Use `include_for_metrics` instead, e.g, `include_for_metrics = ["inputs"]`. + include_for_metrics (`list[str]`, *optional*, defaults to `[]`): + Include additional data in the `compute_metrics` function if needed for metrics computation. + Possible options to add to `include_for_metrics` list: + - `"inputs"`: Input data passed to the model, intended for calculating input dependent metrics. + - `"loss"`: Loss values computed during evaluation, intended for calculating loss dependent metrics. + eval_do_concat_batches (`bool`, *optional*, defaults to `True`): + Whether to recursively concat inputs/losses/labels/predictions across batches. If `False`, + will instead store them as lists, with each batch kept separate. + auto_find_batch_size (`bool`, *optional*, defaults to `False`) + Whether to find a batch size that will fit into memory automatically through exponential decay, avoiding + CUDA Out-of-Memory errors. Requires accelerate to be installed (`pip install accelerate`) + full_determinism (`bool`, *optional*, defaults to `False`) + If `True`, [`enable_full_determinism`] is called instead of [`set_seed`] to ensure reproducible results in + distributed training. Important: this will negatively impact the performance, so only use it for debugging. + torchdynamo (`str`, *optional*): + If set, the backend compiler for TorchDynamo. Possible choices are `"eager"`, `"aot_eager"`, `"inductor"`, + `"nvfuser"`, `"aot_nvfuser"`, `"aot_cudagraphs"`, `"ofi"`, `"fx2trt"`, `"onnxrt"` and `"ipex"`. + ray_scope (`str`, *optional*, defaults to `"last"`): + The scope to use when doing hyperparameter search with Ray. By default, `"last"` will be used. Ray will + then use the last checkpoint of all trials, compare those, and select the best one. However, other options + are also available. See the [Ray documentation]( + https://docs.ray.io/en/latest/tune/api_docs/analysis.html#ray.tune.ExperimentAnalysis.get_best_trial) for + more options. + ddp_timeout (`int`, *optional*, defaults to 1800): + The timeout for `torch.distributed.init_process_group` calls, used to avoid GPU socket timeouts when + performing slow operations in distributed runnings. Please refer the [PyTorch documentation] + (https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) for more + information. + use_mps_device (`bool`, *optional*, defaults to `False`): + This argument is deprecated.`mps` device will be used if it is available similar to `cuda` device. + torch_compile (`bool`, *optional*, defaults to `False`): + Whether or not to compile the model using PyTorch 2.0 + [`torch.compile`](https://pytorch.org/get-started/pytorch-2.0/). + + This will use the best defaults for the [`torch.compile` + API](https://pytorch.org/docs/stable/generated/torch.compile.html?highlight=torch+compile#torch.compile). + You can customize the defaults with the argument `torch_compile_backend` and `torch_compile_mode` but we + don't guarantee any of them will work as the support is progressively rolled in in PyTorch. + + This flag and the whole compile API is experimental and subject to change in future releases. + torch_compile_backend (`str`, *optional*): + The backend to use in `torch.compile`. If set to any value, `torch_compile` will be set to `True`. + + Refer to the PyTorch doc for possible values and note that they may change across PyTorch versions. + + This flag is experimental and subject to change in future releases. + torch_compile_mode (`str`, *optional*): + The mode to use in `torch.compile`. If set to any value, `torch_compile` will be set to `True`. + + Refer to the PyTorch doc for possible values and note that they may change across PyTorch versions. + + This flag is experimental and subject to change in future releases. + include_tokens_per_second (`bool`, *optional*, defaults to `False`): + Whether or not to compute the number of tokens per second per device for training speed metrics. + + This will iterate over the entire training dataloader once beforehand, + and will slow down the entire process. + + include_num_input_tokens_seen (`bool`, *optional*): + Whether or not to track the number of input tokens seen throughout training. + + May be slower in distributed training as gather operations must be called. + + neftune_noise_alpha (`Optional[float]`): + If not `None`, this will activate NEFTune noise embeddings. This can drastically improve model performance + for instruction fine-tuning. Check out the [original paper](https://huggingface.co/papers/2310.05914) and the + [original code](https://github.com/neelsjain/NEFTune). Support transformers `PreTrainedModel` and also + `PeftModel` from peft. The original paper used values in the range [5.0, 15.0]. + optim_target_modules (`Union[str, list[str]]`, *optional*): + The target modules to optimize, i.e. the module names that you would like to train. + Currently used for the GaLore algorithm (https://huggingface.co/papers/2403.03507) and APOLLO algorithm (https://huggingface.co/papers/2412.05270). + See GaLore implementation (https://github.com/jiaweizzhao/GaLore) and APOLLO implementation (https://github.com/zhuhanqing/APOLLO) for more details. + You need to make sure to pass a valid GaLore or APOLLO optimizer, e.g., one of: "apollo_adamw", "galore_adamw", "galore_adamw_8bit", "galore_adafactor" and make sure that the target modules are `nn.Linear` modules only. + + batch_eval_metrics (`bool`, *optional*, defaults to `False`): + If set to `True`, evaluation will call compute_metrics at the end of each batch to accumulate statistics + rather than saving all eval logits in memory. When set to `True`, you must pass a compute_metrics function + that takes a boolean argument `compute_result`, which when passed `True`, will trigger the final global + summary statistics from the batch-level summary statistics you've accumulated over the evaluation set. + + eval_on_start (`bool`, *optional*, defaults to `False`): + Whether to perform a evaluation step (sanity check) before the training to ensure the validation steps works correctly. + + eval_use_gather_object (`bool`, *optional*, defaults to `False`): + Whether to run recursively gather object in a nested list/tuple/dictionary of objects from all devices. This should only be enabled if users are not just returning tensors, and this is actively discouraged by PyTorch. + + use_liger_kernel (`bool`, *optional*, defaults to `False`): + Whether enable [Liger](https://github.com/linkedin/Liger-Kernel) Kernel for LLM model training. + It can effectively increase multi-GPU training throughput by ~20% and reduces memory usage by ~60%, works out of the box with + flash attention, PyTorch FSDP, and Microsoft DeepSpeed. Currently, it supports llama, mistral, mixtral and gemma models. + + liger_kernel_config (`Optional[dict]`, *optional*): + Configuration to be used for Liger Kernel. When use_liger_kernel=True, this dict is passed as keyword arguments to the + `_apply_liger_kernel_to_instance` function, which specifies which kernels to apply. Available options vary by model but typically + include: 'rope', 'swiglu', 'cross_entropy', 'fused_linear_cross_entropy', 'rms_norm', etc. If `None`, use the default kernel configurations. + + average_tokens_across_devices (`bool`, *optional*, defaults to `True`): + Whether or not to average tokens across devices. If enabled, will use all_reduce to synchronize + num_tokens_in_batch for precise loss calculation. Reference: + https://github.com/huggingface/transformers/issues/34242 + """ + + # Sometimes users will pass in a `str` repr of a dict in the CLI + # We need to track what fields those can be. Each time a new arg + # has a dict type, it must be added to this list. + # Important: These should be typed with Optional[Union[dict,str,...]] + _VALID_DICT_FIELDS = [ + "accelerator_config", + "fsdp_config", + "deepspeed", + "gradient_checkpointing_kwargs", + "lr_scheduler_kwargs", + ] + framework = "pt" + + output_dir: Optional[str] = field( + default=None, + metadata={ + "help": "The output directory where the model predictions and checkpoints will be written. Defaults to 'trainer_output' if not provided." + }, + ) + overwrite_output_dir: bool = field( + default=False, + metadata={ + "help": ( + "Overwrite the content of the output directory. " + "Use this to continue training if output_dir points to a checkpoint directory." + ) + }, + ) + + do_train: bool = field(default=False, metadata={"help": "Whether to run training."}) + do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."}) + do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."}) + eval_strategy: Union[IntervalStrategy, str] = field( + default="no", + metadata={"help": "The evaluation strategy to use."}, + ) + prediction_loss_only: bool = field( + default=False, + metadata={"help": "When performing evaluation and predictions, only returns the loss."}, + ) + + per_device_train_batch_size: int = field( + default=8, metadata={"help": "Batch size per device accelerator core/CPU for training."} + ) + per_device_eval_batch_size: int = field( + default=8, metadata={"help": "Batch size per device accelerator core/CPU for evaluation."} + ) + + per_gpu_train_batch_size: Optional[int] = field( + default=None, + metadata={ + "help": ( + "Deprecated, the use of `--per_device_train_batch_size` is preferred. " + "Batch size per GPU/TPU core/CPU for training." + ) + }, + ) + per_gpu_eval_batch_size: Optional[int] = field( + default=None, + metadata={ + "help": ( + "Deprecated, the use of `--per_device_eval_batch_size` is preferred. " + "Batch size per GPU/TPU core/CPU for evaluation." + ) + }, + ) + + gradient_accumulation_steps: int = field( + default=1, + metadata={"help": "Number of updates steps to accumulate before performing a backward/update pass."}, + ) + eval_accumulation_steps: Optional[int] = field( + default=None, + metadata={"help": "Number of predictions steps to accumulate before moving the tensors to the CPU."}, + ) + + eval_delay: float = field( + default=0, + metadata={ + "help": ( + "Number of epochs or steps to wait for before the first evaluation can be performed, depending on the" + " eval_strategy." + ) + }, + ) + + torch_empty_cache_steps: Optional[int] = field( + default=None, + metadata={ + "help": "Number of steps to wait before calling `torch..empty_cache()`." + "This can help avoid CUDA out-of-memory errors by lowering peak VRAM usage at a cost of about [10% slower performance](https://github.com/huggingface/transformers/issues/31372)." + "If left unset or set to None, cache will not be emptied." + }, + ) + + learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."}) + weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."}) + adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"}) + adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"}) + adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."}) + max_grad_norm: float = field(default=1.0, metadata={"help": "Max gradient norm."}) + + num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."}) + max_steps: int = field( + default=-1, + metadata={"help": "If > 0: set total number of training steps to perform. Override num_train_epochs."}, + ) + lr_scheduler_type: Union[SchedulerType, str] = field( + default="linear", + metadata={"help": "The scheduler type to use."}, + ) + lr_scheduler_kwargs: Optional[Union[dict, str]] = field( + default=None, + metadata={ + "help": ( + "Extra parameters for the lr_scheduler such as {'num_cycles': 1} for the cosine with hard restarts." + ) + }, + ) + warmup_ratio: float = field( + default=0.0, metadata={"help": "Linear warmup over warmup_ratio fraction of total steps."} + ) + warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."}) + + log_level: str = field( + default="passive", + metadata={ + "help": ( + "Logger log level to use on the main node. Possible choices are the log levels as strings: 'debug'," + " 'info', 'warning', 'error' and 'critical', plus a 'passive' level which doesn't set anything and" + " lets the application set the level. Defaults to 'passive'." + ), + "choices": trainer_log_levels.keys(), + }, + ) + log_level_replica: str = field( + default="warning", + metadata={ + "help": "Logger log level to use on replica nodes. Same choices and defaults as ``log_level``", + "choices": trainer_log_levels.keys(), + }, + ) + log_on_each_node: bool = field( + default=True, + metadata={ + "help": ( + "When doing a multinode distributed training, whether to log once per node or just once on the main" + " node." + ) + }, + ) + logging_dir: Optional[str] = field(default=None, metadata={"help": "Tensorboard log dir."}) + logging_strategy: Union[IntervalStrategy, str] = field( + default="steps", + metadata={"help": "The logging strategy to use."}, + ) + logging_first_step: bool = field(default=False, metadata={"help": "Log the first global_step"}) + logging_steps: float = field( + default=500, + metadata={ + "help": ( + "Log every X updates steps. Should be an integer or a float in range `[0,1)`. " + "If smaller than 1, will be interpreted as ratio of total training steps." + ) + }, + ) + logging_nan_inf_filter: bool = field(default=True, metadata={"help": "Filter nan and inf losses for logging."}) + save_strategy: Union[SaveStrategy, str] = field( + default="steps", + metadata={"help": "The checkpoint save strategy to use."}, + ) + save_steps: float = field( + default=500, + metadata={ + "help": ( + "Save checkpoint every X updates steps. Should be an integer or a float in range `[0,1)`. " + "If smaller than 1, will be interpreted as ratio of total training steps." + ) + }, + ) + save_total_limit: Optional[int] = field( + default=None, + metadata={ + "help": ( + "If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in" + " `output_dir`. When `load_best_model_at_end` is enabled, the 'best' checkpoint according to" + " `metric_for_best_model` will always be retained in addition to the most recent ones. For example," + " for `save_total_limit=5` and `load_best_model_at_end=True`, the four last checkpoints will always be" + " retained alongside the best model. When `save_total_limit=1` and `load_best_model_at_end=True`," + " it is possible that two checkpoints are saved: the last one and the best one (if they are different)." + " Default is unlimited checkpoints" + ) + }, + ) + save_safetensors: bool = field( + default=True, + metadata={ + "help": "Use safetensors saving and loading for state dicts instead of default torch.load and torch.save." + }, + ) + save_on_each_node: bool = field( + default=False, + metadata={ + "help": ( + "When doing multi-node distributed training, whether to save models and checkpoints on each node, or" + " only on the main one" + ) + }, + ) + save_only_model: bool = field( + default=False, + metadata={ + "help": ( + "When checkpointing, whether to only save the model, or also the optimizer, scheduler & rng state." + "Note that when this is true, you won't be able to resume training from checkpoint." + "This enables you to save storage by not storing the optimizer, scheduler & rng state." + "You can only load the model using from_pretrained with this option set to True." + ) + }, + ) + restore_callback_states_from_checkpoint: bool = field( + default=False, + metadata={ + "help": "Whether to restore the callback states from the checkpoint. If `True`, will override callbacks passed to the `Trainer` if they exist in the checkpoint." + }, + ) + no_cuda: bool = field( + default=False, + metadata={"help": "This argument is deprecated. It will be removed in version 5.0 of 🤗 Transformers."}, + ) + use_cpu: bool = field( + default=False, + metadata={ + "help": "Whether or not to use cpu. If left to False, we will use the available torch device/backend (cuda/mps/xpu/hpu etc.)" + }, + ) + use_mps_device: bool = field( + default=False, + metadata={ + "help": "This argument is deprecated. `mps` device will be used if available similar to `cuda` device." + " It will be removed in version 5.0 of 🤗 Transformers" + }, + ) + seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."}) + data_seed: Optional[int] = field(default=None, metadata={"help": "Random seed to be used with data samplers."}) + jit_mode_eval: bool = field( + default=False, metadata={"help": "Whether or not to use PyTorch jit trace for inference"} + ) + bf16: bool = field( + default=False, + metadata={ + "help": ( + "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA" + " architecture or using CPU (use_cpu) or Ascend NPU. This is an experimental API and it may change." + ) + }, + ) + fp16: bool = field( + default=False, + metadata={"help": "Whether to use fp16 (mixed) precision instead of 32-bit"}, + ) + fp16_opt_level: str = field( + default="O1", + metadata={ + "help": ( + "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. " + "See details at https://nvidia.github.io/apex/amp.html" + ) + }, + ) + half_precision_backend: str = field( + default="auto", + metadata={ + "help": "The backend to be used for half precision.", + "choices": ["auto", "apex", "cpu_amp"], + }, + ) + bf16_full_eval: bool = field( + default=False, + metadata={ + "help": ( + "Whether to use full bfloat16 evaluation instead of 32-bit. This is an experimental API and it may" + " change." + ) + }, + ) + fp16_full_eval: bool = field( + default=False, + metadata={"help": "Whether to use full float16 evaluation instead of 32-bit"}, + ) + tf32: Optional[bool] = field( + default=None, + metadata={ + "help": ( + "Whether to enable tf32 mode, available in Ampere and newer GPU architectures. This is an experimental" + " API and it may change." + ) + }, + ) + local_rank: int = field(default=-1, metadata={"help": "For distributed training: local_rank"}) + ddp_backend: Optional[str] = field( + default=None, + metadata={ + "help": "The backend to be used for distributed training", + "choices": ["nccl", "gloo", "mpi", "ccl", "hccl", "cncl", "mccl"], + }, + ) + tpu_num_cores: Optional[int] = field( + default=None, metadata={"help": "TPU: Number of TPU cores (automatically passed by launcher script)"} + ) + tpu_metrics_debug: bool = field( + default=False, + metadata={ + "help": ( + "Deprecated, the use of `--debug tpu_metrics_debug` is preferred. TPU: Whether to print debug metrics" + ) + }, + ) + debug: Union[str, list[DebugOption]] = field( + default="", + metadata={ + "help": ( + "Whether or not to enable debug mode. Current options: " + "`underflow_overflow` (Detect underflow and overflow in activations and weights), " + "`tpu_metrics_debug` (print debug metrics on TPU)." + ) + }, + ) + + dataloader_drop_last: bool = field( + default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."} + ) + eval_steps: Optional[float] = field( + default=None, + metadata={ + "help": ( + "Run an evaluation every X steps. Should be an integer or a float in range `[0,1)`. " + "If smaller than 1, will be interpreted as ratio of total training steps." + ) + }, + ) + dataloader_num_workers: int = field( + default=0, + metadata={ + "help": ( + "Number of subprocesses to use for data loading (PyTorch only). 0 means that the data will be loaded" + " in the main process." + ) + }, + ) + dataloader_prefetch_factor: Optional[int] = field( + default=None, + metadata={ + "help": ( + "Number of batches loaded in advance by each worker. " + "2 means there will be a total of 2 * num_workers batches prefetched across all workers. " + ) + }, + ) + past_index: int = field( + default=-1, + metadata={"help": "If >=0, uses the corresponding part of the output as the past state for next step."}, + ) + + run_name: Optional[str] = field( + default=None, + metadata={ + "help": ( + "An optional descriptor for the run. Notably used for trackio, wandb, mlflow comet and swanlab " + "logging." + ) + }, + ) + disable_tqdm: Optional[bool] = field( + default=None, metadata={"help": "Whether or not to disable the tqdm progress bars."} + ) + + remove_unused_columns: bool = field( + default=True, metadata={"help": "Remove columns not required by the model when using an nlp.Dataset."} + ) + label_names: Optional[list[str]] = field( + default=None, metadata={"help": "The list of keys in your dictionary of inputs that correspond to the labels."} + ) + load_best_model_at_end: bool = field( + default=False, + metadata={ + "help": ( + "Whether or not to load the best model found during training at the end of training. When this option" + " is enabled, the best checkpoint will always be saved. See `save_total_limit` for more." + ) + }, + ) + metric_for_best_model: Optional[str] = field( + default=None, metadata={"help": "The metric to use to compare two different models."} + ) + greater_is_better: Optional[bool] = field( + default=None, metadata={"help": "Whether the `metric_for_best_model` should be maximized or not."} + ) + ignore_data_skip: bool = field( + default=False, + metadata={ + "help": ( + "When resuming training, whether or not to skip the first epochs and batches to get to the same" + " training data." + ) + }, + ) + fsdp: Optional[Union[list[FSDPOption], str]] = field( + default=None, + metadata={ + "help": ( + "Whether or not to use PyTorch Fully Sharded Data Parallel (FSDP) training (in distributed training" + " only). The base option should be `full_shard`, `shard_grad_op` or `no_shard` and you can add" + " CPU-offload to `full_shard` or `shard_grad_op` like this: full_shard offload` or `shard_grad_op" + " offload`. You can add auto-wrap to `full_shard` or `shard_grad_op` with the same syntax: full_shard" + " auto_wrap` or `shard_grad_op auto_wrap`." + ), + }, + ) + fsdp_min_num_params: int = field( + default=0, + metadata={ + "help": ( + "This parameter is deprecated. FSDP's minimum number of parameters for Default Auto Wrapping. (useful" + " only when `fsdp` field is passed)." + ) + }, + ) + fsdp_config: Optional[Union[dict[str, Any], str]] = field( + default=None, + metadata={ + "help": ( + "Config to be used with FSDP (Pytorch Fully Sharded Data Parallel). The value is either a " + "fsdp json config file (e.g., `fsdp_config.json`) or an already loaded json file as `dict`." + ) + }, + ) + fsdp_transformer_layer_cls_to_wrap: Optional[str] = field( + default=None, + metadata={ + "help": ( + "This parameter is deprecated. Transformer layer class name (case-sensitive) to wrap, e.g," + " `BertLayer`, `GPTJBlock`, `T5Block` .... (useful only when `fsdp` flag is passed)." + ) + }, + ) + accelerator_config: Optional[Union[dict, str]] = field( + default=None, + metadata={ + "help": ( + "Config to be used with the internal Accelerator object initialization. The value is either a " + "accelerator json config file (e.g., `accelerator_config.json`) or an already loaded json file as `dict`." + ) + }, + ) + parallelism_config: Optional[ParallelismConfig] = field( + default=None, + metadata={"help": ("Parallelism configuration for the training run. Requires Accelerate `1.10.1`")}, + ) + deepspeed: Optional[Union[dict, str]] = field( + default=None, + metadata={ + "help": ( + "Enable deepspeed and pass the path to deepspeed json config file (e.g. `ds_config.json`) or an already" + " loaded json file as a dict" + ) + }, + ) + label_smoothing_factor: float = field( + default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."} + ) + + default_optim = "adamw_torch" + if is_torch_available(): + from .pytorch_utils import is_torch_greater_or_equal_than_2_8 + + if is_torch_greater_or_equal_than_2_8: + default_optim = "adamw_torch_fused" + optim: Union[OptimizerNames, str] = field( + default=default_optim, + metadata={"help": "The optimizer to use."}, + ) + optim_args: Optional[str] = field(default=None, metadata={"help": "Optional arguments to supply to optimizer."}) + adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."}) + group_by_length: bool = field( + default=False, + metadata={"help": "Whether or not to group samples of roughly the same length together when batching."}, + ) + length_column_name: str = field( + default="length", + metadata={"help": "Column name with precomputed lengths to use when grouping by length."}, + ) + report_to: Union[None, str, list[str]] = field( + default=None, metadata={"help": "The list of integrations to report the results and logs to."} + ) + project: str = field( + default="huggingface", + metadata={"help": "The name of the project to use for logging. Currenly, only used by Trackio."}, + ) + trackio_space_id: Optional[str] = field( + default="trackio", + metadata={ + "help": "The Hugging Face Space ID to deploy to when using Trackio. Should be a complete Space name like " + "'username/reponame' or 'orgname/reponame', or just 'reponame' in which case the Space will be created in " + "the currently-logged-in Hugging Face user's namespace. If `None`, will log to a local directory. Note " + "that this Space will be public unless you set `hub_private_repo=True` or your organization's " + "default is to create private Spaces." + }, + ) + ddp_find_unused_parameters: Optional[bool] = field( + default=None, + metadata={ + "help": ( + "When using distributed training, the value of the flag `find_unused_parameters` passed to " + "`DistributedDataParallel`." + ) + }, + ) + ddp_bucket_cap_mb: Optional[int] = field( + default=None, + metadata={ + "help": ( + "When using distributed training, the value of the flag `bucket_cap_mb` passed to " + "`DistributedDataParallel`." + ) + }, + ) + ddp_broadcast_buffers: Optional[bool] = field( + default=None, + metadata={ + "help": ( + "When using distributed training, the value of the flag `broadcast_buffers` passed to " + "`DistributedDataParallel`." + ) + }, + ) + dataloader_pin_memory: bool = field( + default=True, metadata={"help": "Whether or not to pin memory for DataLoader."} + ) + dataloader_persistent_workers: bool = field( + default=False, + metadata={ + "help": "If True, the data loader will not shut down the worker processes after a dataset has been consumed once. This allows to maintain the workers Dataset instances alive. Can potentially speed up training, but will increase RAM usage." + }, + ) + skip_memory_metrics: bool = field( + default=True, metadata={"help": "Whether or not to skip adding of memory profiler reports to metrics."} + ) + use_legacy_prediction_loop: bool = field( + default=False, metadata={"help": "Whether or not to use the legacy prediction_loop in the Trainer."} + ) + push_to_hub: bool = field( + default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."} + ) + resume_from_checkpoint: Optional[str] = field( + default=None, + metadata={"help": "The path to a folder with a valid checkpoint for your model."}, + ) + hub_model_id: Optional[str] = field( + default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."} + ) + hub_strategy: Union[HubStrategy, str] = field( + default="every_save", + metadata={"help": "The hub strategy to use when `--push_to_hub` is activated."}, + ) + hub_token: Optional[str] = field(default=None, metadata={"help": "The token to use to push to the Model Hub."}) + hub_private_repo: Optional[bool] = field( + default=None, + metadata={ + "help": "Whether to make the repo private. If `None` (default), the repo will be public unless the " + "organization's default is private. This value is ignored if the repo already exists. If reporting to " + "Trackio with deployment to Hugging Face Spaces enabled, the same logic determines whether the Space is " + "private." + }, + ) + hub_always_push: bool = field( + default=False, + metadata={"help": "Unless `True`, the Trainer will skip pushes if the previous one wasn't finished yet."}, + ) + hub_revision: Optional[str] = field( + default=None, + metadata={ + "help": "The revision to use when pushing to the Hub. Can be a branch name, a tag, or a commit hash." + }, + ) + gradient_checkpointing: bool = field( + default=False, + metadata={ + "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass." + }, + ) + gradient_checkpointing_kwargs: Optional[Union[dict[str, Any], str]] = field( + default=None, + metadata={ + "help": "Gradient checkpointing key word arguments such as `use_reentrant`. Will be passed to `torch.utils.checkpoint.checkpoint` through `model.gradient_checkpointing_enable`." + }, + ) + include_inputs_for_metrics: bool = field( + default=False, + metadata={ + "help": "This argument is deprecated and will be removed in version 5 of 🤗 Transformers. Use `include_for_metrics` instead." + }, + ) + include_for_metrics: list[str] = field( + default_factory=list, + metadata={ + "help": "List of strings to specify additional data to include in the `compute_metrics` function." + "Options: 'inputs', 'loss'." + }, + ) + eval_do_concat_batches: bool = field( + default=True, + metadata={ + "help": "Whether to recursively concat inputs/losses/labels/predictions across batches. If `False`, will instead store them as lists, with each batch kept separate." + }, + ) + # Deprecated arguments + fp16_backend: str = field( + default="auto", + metadata={ + "help": "Deprecated. Use half_precision_backend instead", + "choices": ["auto", "apex", "cpu_amp"], + }, + ) + push_to_hub_model_id: Optional[str] = field( + default=None, metadata={"help": "The name of the repository to which push the `Trainer`."} + ) + push_to_hub_organization: Optional[str] = field( + default=None, metadata={"help": "The name of the organization in with to which push the `Trainer`."} + ) + push_to_hub_token: Optional[str] = field( + default=None, metadata={"help": "The token to use to push to the Model Hub."} + ) + _n_gpu: int = field(init=False, repr=False, default=-1) + mp_parameters: str = field( + default="", + metadata={"help": "Used by the SageMaker launcher to send mp-specific args. Ignored in Trainer"}, + ) + + auto_find_batch_size: bool = field( + default=False, + metadata={ + "help": ( + "Whether to automatically decrease the batch size in half and rerun the training loop again each time" + " a CUDA Out-of-Memory was reached" + ) + }, + ) + full_determinism: bool = field( + default=False, + metadata={ + "help": ( + "Whether to call enable_full_determinism instead of set_seed for reproducibility in distributed" + " training. Important: this will negatively impact the performance, so only use it for debugging." + ) + }, + ) + torchdynamo: Optional[str] = field( + default=None, + metadata={ + "help": "This argument is deprecated, use `--torch_compile_backend` instead.", + }, + ) + ray_scope: Optional[str] = field( + default="last", + metadata={ + "help": ( + 'The scope to use when doing hyperparameter search with Ray. By default, `"last"` will be used. Ray' + " will then use the last checkpoint of all trials, compare those, and select the best one. However," + " other options are also available. See the Ray documentation" + " (https://docs.ray.io/en/latest/tune/api_docs/analysis.html" + "#ray.tune.ExperimentAnalysis.get_best_trial)" + " for more options." + ) + }, + ) + ddp_timeout: int = field( + default=1800, + metadata={ + "help": "Overrides the default timeout for distributed training (value should be given in seconds)." + }, + ) + torch_compile: bool = field( + default=False, metadata={"help": "If set to `True`, the model will be wrapped in `torch.compile`."} + ) + torch_compile_backend: Optional[str] = field( + default=None, + metadata={ + "help": "Which backend to use with `torch.compile`, passing one will trigger a model compilation.", + }, + ) + torch_compile_mode: Optional[str] = field( + default=None, + metadata={ + "help": "Which mode to use with `torch.compile`, passing one will trigger a model compilation.", + }, + ) + + include_tokens_per_second: bool = field( + default=False, + metadata={"help": "If set to `True`, the speed metrics will include `tgs` (tokens per second per device)."}, + ) + + include_num_input_tokens_seen: Union[str, bool] = field( + default=False, + metadata={ + "help": ( + "Whether to track the number of input tokens seen. " + "Can be `'all'` to count all tokens, `'non_padding'` to count only non-padding tokens, " + "or a boolean (`True` maps to `'all'`, `False` to `'no'`)." + ) + }, + ) + + neftune_noise_alpha: Optional[float] = field( + default=None, + metadata={ + "help": "Activates neftune noise embeddings into the model. NEFTune has been proven to drastically improve model performances for instruction fine-tuning. Check out the original paper here: https://huggingface.co/papers/2310.05914 and the original code here: https://github.com/neelsjain/NEFTune. Only supported for `PreTrainedModel` and `PeftModel` classes." + }, + ) + + optim_target_modules: Union[None, str, list[str]] = field( + default=None, + metadata={ + "help": "Target modules for the optimizer defined in the `optim` argument. Only used for the GaLore optimizer at the moment." + }, + ) + + batch_eval_metrics: bool = field( + default=False, + metadata={"help": "Break eval metrics calculation into batches to save memory."}, + ) + + eval_on_start: bool = field( + default=False, + metadata={ + "help": "Whether to run through the entire `evaluation` step at the very beginning of training as a sanity check." + }, + ) + + use_liger_kernel: bool = field( + default=False, + metadata={"help": "Whether or not to enable the Liger Kernel for model training."}, + ) + + liger_kernel_config: Optional[dict[str, bool]] = field( + default=None, + metadata={ + "help": ( + "Configuration to be used for Liger Kernel. When use_liger_kernel=True, " + "this dict is passed as keyword arguments to the `_apply_liger_kernel_to_instance` function, " + "which specifies which kernels to apply. Available options vary by model " + "but typically include: 'rope', 'swiglu', 'cross_entropy', 'fused_linear_cross_entropy', " + "'rms_norm', etc. If None, use the default kernel configurations." + ) + }, + ) + + eval_use_gather_object: bool = field( + default=False, + metadata={ + "help": "Whether to run recursively gather object in a nested list/tuple/dictionary of objects from all devices." + }, + ) + + average_tokens_across_devices: bool = field( + default=True, + metadata={ + "help": "Whether or not to average tokens across devices. If enabled, will use all_reduce to " + "synchronize num_tokens_in_batch for precise loss calculation. Reference: " + "https://github.com/huggingface/transformers/issues/34242" + }, + ) + + def __post_init__(self): + # Set default output_dir if not provided + if self.output_dir is None: + self.output_dir = "trainer_output" + logger.info( + "No output directory specified, defaulting to 'trainer_output'. " + "To change this behavior, specify --output_dir when creating TrainingArguments." + ) + + # Parse in args that could be `dict` sent in from the CLI as a string + for field in self._VALID_DICT_FIELDS: + passed_value = getattr(self, field) + # We only want to do this if the str starts with a bracket to indicate a `dict` + # else its likely a filename if supported + if isinstance(passed_value, str) and passed_value.startswith("{"): + loaded_dict = json.loads(passed_value) + # Convert str values to types if applicable + loaded_dict = _convert_str_dict(loaded_dict) + setattr(self, field, loaded_dict) + + # expand paths, if not os.makedirs("~/bar") will make directory + # in the current directory instead of the actual home + # see https://github.com/huggingface/transformers/issues/10628 + if self.output_dir is not None: + self.output_dir = os.path.expanduser(self.output_dir) + if self.logging_dir is None and self.output_dir is not None: + self.logging_dir = os.path.join(self.output_dir, default_logdir()) + if self.logging_dir is not None: + self.logging_dir = os.path.expanduser(self.logging_dir) + + if self.disable_tqdm is None: + self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN + + if isinstance(self.eval_strategy, EvaluationStrategy): + warnings.warn( + "using `EvaluationStrategy` for `eval_strategy` is deprecated and will be removed in version 5" + " of 🤗 Transformers. Use `IntervalStrategy` instead", + FutureWarning, + ) + # Go back to the underlying string or we won't be able to instantiate `IntervalStrategy` on it. + self.eval_strategy = self.eval_strategy.value + if self.no_cuda: + warnings.warn( + "using `no_cuda` is deprecated and will be removed in version 5.0 of 🤗 Transformers. " + "Use `use_cpu` instead", + FutureWarning, + ) + self.use_cpu = self.no_cuda + + self.eval_strategy = IntervalStrategy(self.eval_strategy) + self.logging_strategy = IntervalStrategy(self.logging_strategy) + self.save_strategy = SaveStrategy(self.save_strategy) + self.hub_strategy = HubStrategy(self.hub_strategy) + + self.lr_scheduler_type = SchedulerType(self.lr_scheduler_type) + if self.do_eval is False and self.eval_strategy != IntervalStrategy.NO: + self.do_eval = True + + if self.torch_empty_cache_steps is not None: + if not (isinstance(self.torch_empty_cache_steps, int) and self.torch_empty_cache_steps > 0): + raise ValueError( + f"`torch_empty_cache_steps` must be an integer bigger than 0, got {self.torch_empty_cache_steps}." + ) + + # eval_steps has to be defined and non-zero, fallbacks to logging_steps if the latter is non-zero + if self.eval_strategy == IntervalStrategy.STEPS and (self.eval_steps is None or self.eval_steps == 0): + if self.logging_steps > 0: + logger.info(f"using `logging_steps` to initialize `eval_steps` to {self.logging_steps}") + self.eval_steps = self.logging_steps + else: + raise ValueError( + f"evaluation strategy {self.eval_strategy} requires either non-zero --eval_steps or" + " --logging_steps" + ) + + # logging_steps must be non-zero for logging_strategy that is other than 'no' + if self.logging_strategy == IntervalStrategy.STEPS and self.logging_steps == 0: + raise ValueError(f"logging strategy {self.logging_strategy} requires non-zero --logging_steps") + + if self.logging_strategy == IntervalStrategy.STEPS and self.logging_steps > 1: + if self.logging_steps != int(self.logging_steps): + raise ValueError(f"--logging_steps must be an integer if bigger than 1: {self.logging_steps}") + self.logging_steps = int(self.logging_steps) + if self.eval_strategy == IntervalStrategy.STEPS and self.eval_steps > 1: + if self.eval_steps != int(self.eval_steps): + raise ValueError(f"--eval_steps must be an integer if bigger than 1: {self.eval_steps}") + self.eval_steps = int(self.eval_steps) + if self.save_strategy == SaveStrategy.STEPS and self.save_steps > 1: + if self.save_steps != int(self.save_steps): + raise ValueError(f"--save_steps must be an integer if bigger than 1: {self.save_steps}") + self.save_steps = int(self.save_steps) + + # Sanity checks for load_best_model_at_end: we require save and eval strategies to be compatible. + if self.load_best_model_at_end and self.save_strategy != SaveStrategy.BEST: + if self.eval_strategy != self.save_strategy: + raise ValueError( + "--load_best_model_at_end requires the save and eval strategy to match, but found\n- Evaluation " + f"strategy: {self.eval_strategy}\n- Save strategy: {self.save_strategy}" + ) + if self.eval_strategy == IntervalStrategy.STEPS and self.save_steps % self.eval_steps != 0: + if self.eval_steps < 1 or self.save_steps < 1: + if not (self.eval_steps < 1 and self.save_steps < 1): + raise ValueError( + "--load_best_model_at_end requires the saving steps to be a multiple of the evaluation " + "steps, which cannot get guaranteed when mixing ratio and absolute steps for save_steps " + f"{self.save_steps} and eval_steps {self.eval_steps}." + ) + # Work around floating point precision issues + LARGE_MULTIPLIER = 1_000_000 + if (self.save_steps * LARGE_MULTIPLIER) % (self.eval_steps * LARGE_MULTIPLIER) != 0: + raise ValueError( + "--load_best_model_at_end requires the saving steps to be a multiple of the evaluation " + f"steps, but found {self.save_steps}, which is not a multiple of {self.eval_steps}." + ) + else: + raise ValueError( + "--load_best_model_at_end requires the saving steps to be a round multiple of the evaluation " + f"steps, but found {self.save_steps}, which is not a round multiple of {self.eval_steps}." + ) + + if not self.save_safetensors: + logger.info( + f"Found safetensors installation, but --save_safetensors={self.save_safetensors}. " + f"Safetensors should be a preferred weights saving format due to security and performance reasons. " + f"If your model cannot be saved by safetensors please feel free to open an issue at " + f"https://github.com/huggingface/safetensors!" + ) + + if ( + self.load_best_model_at_end or self.lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU + ) and self.metric_for_best_model is None: + self.metric_for_best_model = "loss" + if self.greater_is_better is None and self.metric_for_best_model is not None: + self.greater_is_better = not self.metric_for_best_model.endswith("loss") + if self.framework == "pt" and is_torch_available(): + if self.fp16_backend and self.fp16_backend != "auto": + warnings.warn( + "`fp16_backend` is deprecated and will be removed in version 5 of 🤗 Transformers. Use" + " `half_precision_backend` instead", + FutureWarning, + ) + self.half_precision_backend = self.fp16_backend + + if self.bf16 or self.bf16_full_eval: + if self.use_cpu and not is_torch_available() and not is_torch_xla_available(): + # cpu + raise ValueError("Your setup doesn't support bf16/(cpu, tpu, neuroncore). You need torch>=1.10") + elif not self.use_cpu: + if not is_torch_bf16_gpu_available() and not is_torch_xla_available(): # added for tpu support + error_message = "Your setup doesn't support bf16/gpu." + if is_torch_cuda_available(): + error_message += " You need Ampere+ GPU with cuda>=11.0" + # gpu + raise ValueError(error_message) + + if self.fp16 and self.bf16: + raise ValueError("At most one of fp16 and bf16 can be True, but not both") + + if self.fp16_full_eval and self.bf16_full_eval: + raise ValueError("At most one of fp16 and bf16 can be True for full eval, but not both") + + if self.bf16: + if self.half_precision_backend == "apex": + raise ValueError(" `--half_precision_backend apex`: GPU bf16 is not supported by apex.") + + if self.half_precision_backend == "apex": + if not is_apex_available(): + raise ImportError( + "Using FP16 with APEX but APEX is not installed, please refer to" + " https://www.github.com/nvidia/apex." + ) + try: + from apex import amp # noqa: F401 + except ImportError as e: + raise ImportError( + f"apex.amp is deprecated in the latest version of apex, causing this error {e}. Either revert to an older version or use pytorch amp by setting half_precision_backend='auto' instead. See https://github.com/NVIDIA/apex/pull/1896 " + ) + + if self.lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU: + if self.eval_strategy == IntervalStrategy.NO: + raise ValueError("lr_scheduler_type reduce_lr_on_plateau requires an eval strategy") + if not is_torch_available(): + raise ValueError("lr_scheduler_type reduce_lr_on_plateau requires torch>=0.2.0") + + self.optim = OptimizerNames(self.optim) + if self.adafactor: + warnings.warn( + "`--adafactor` is deprecated and will be removed in version 5 of 🤗 Transformers. Use `--optim" + " adafactor` instead", + FutureWarning, + ) + self.optim = OptimizerNames.ADAFACTOR + + # We need to setup the accelerator config here *before* the first call to `self.device` + if is_accelerate_available(): + if not isinstance(self.accelerator_config, AcceleratorConfig): + if self.accelerator_config is None: + self.accelerator_config = AcceleratorConfig() + elif isinstance(self.accelerator_config, dict): + self.accelerator_config = AcceleratorConfig(**self.accelerator_config) + # Check that a user didn't pass in the class instantiator + # such as `accelerator_config = AcceleratorConfig` + elif isinstance(self.accelerator_config, type): + raise NotImplementedError( + "Tried passing in a callable to `accelerator_config`, but this is not supported. " + "Please pass in a fully constructed `AcceleratorConfig` object instead." + ) + else: + self.accelerator_config = AcceleratorConfig.from_json_file(self.accelerator_config) + if self.accelerator_config.split_batches: + logger.info( + "Using `split_batches=True` in `accelerator_config` will override the `per_device_train_batch_size` " + "Batches will be split across all processes equally when using `split_batches=True`." + ) + + # Initialize device before we proceed + if self.framework == "pt" and is_torch_available(): + self.device + + if self.torchdynamo is not None: + warnings.warn( + "`torchdynamo` is deprecated and will be removed in version 5 of 🤗 Transformers. Use" + " `torch_compile_backend` instead", + FutureWarning, + ) + self.torch_compile_backend = self.torchdynamo + if (self.torch_compile_mode is not None or self.torch_compile_backend is not None) and not self.torch_compile: + self.torch_compile = True + if self.torch_compile and self.torch_compile_backend is None: + if not self.use_cpu and is_torch_hpu_available(): + self.torch_compile_backend = "hpu_backend" + else: + self.torch_compile_backend = "inductor" + + # accelerate integration for torch compile + if self.torch_compile: + # set env vars for accelerate + prefix = "ACCELERATE_DYNAMO_" + os.environ[prefix + "BACKEND"] = self.torch_compile_backend + if self.torch_compile_mode is not None: + os.environ[prefix + "MODE"] = self.torch_compile_mode + + if self.framework == "pt" and is_torch_available() and self.torch_compile: + if is_torch_tf32_available(): + if self.tf32 is None and not self.fp16 or self.bf16: + device_str = "MUSA" if is_torch_musa_available() else "CUDA" + logger.info( + f"Setting TF32 in {device_str} backends to speedup torch compile, you won't see any improvement" + " otherwise." + ) + if is_torch_musa_available(): + torch.backends.mudnn.allow_tf32 = True + else: + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + else: + logger.warning( + "The speedups for torchdynamo mostly come with GPU Ampere or higher and which is not detected here." + ) + if self.framework == "pt" and is_torch_available() and self.tf32 is not None: + if self.tf32: + if is_torch_tf32_available(): + if is_torch_musa_available(): + torch.backends.mudnn.allow_tf32 = True + else: + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + else: + raise ValueError("--tf32 requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7") + else: + if is_torch_tf32_available(): + if is_torch_musa_available(): + torch.backends.mudnn.allow_tf32 = False + else: + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + # no need to assert on else + + # if training args is specified, it will override the one specified in the accelerate config + mixed_precision_dtype = os.environ.get("ACCELERATE_MIXED_PRECISION", "no") + if self.fp16: + mixed_precision_dtype = "fp16" + elif self.bf16: + mixed_precision_dtype = "bf16" + os.environ["ACCELERATE_MIXED_PRECISION"] = mixed_precision_dtype + + if self.report_to is None: + logger.info( + "The default value for the training argument `--report_to` will change in v5 (from all installed " + "integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as " + "now. You should start updating your code and make this info disappear :-)." + ) + self.report_to = "all" + if self.report_to == "all" or self.report_to == ["all"]: + # Import at runtime to avoid a circular import. + from .integrations import get_available_reporting_integrations + + self.report_to = get_available_reporting_integrations() + + if "codecarbon" in self.report_to and torch.version.hip: + logger.warning( + "When using the Trainer, CodeCarbonCallback requires the `codecarbon` package, which is not compatible with AMD ROCm (https://github.com/mlco2/codecarbon/pull/490). Automatically disabling the codecarbon callback. Reference: https://huggingface.co/docs/transformers/v4.39.3/en/main_classes/trainer#transformers.TrainingArguments.report_to." + ) + self.report_to.remove("codecarbon") + + elif self.report_to == "none" or self.report_to == ["none"]: + self.report_to = [] + elif not isinstance(self.report_to, list): + self.report_to = [self.report_to] + + if self.warmup_ratio < 0 or self.warmup_ratio > 1: + raise ValueError("warmup_ratio must lie in range [0,1]") + elif self.warmup_ratio > 0 and self.warmup_steps > 0: + logger.info( + "Both warmup_ratio and warmup_steps given, warmup_steps will override any effect of warmup_ratio" + " during training" + ) + + if not isinstance(self.warmup_steps, int) or self.warmup_steps < 0: + raise ValueError("warmup_steps must be of type int and must be 0 or a positive integer.") + + if self.fsdp is None: + self.fsdp = [] + elif self.fsdp is True: + self.fsdp = [FSDPOption.FULL_SHARD] + elif isinstance(self.fsdp, str): + self.fsdp = [FSDPOption(s) for s in self.fsdp.split()] + + if self.fsdp == [FSDPOption.OFFLOAD]: + raise ValueError( + "`--fsdp offload` can't work on its own. It needs to be added to `--fsdp full_shard` or " + '`--fsdp shard_grad_op`. For example, `--fsdp "full_shard offload"`.' + ) + elif FSDPOption.FULL_SHARD in self.fsdp and FSDPOption.SHARD_GRAD_OP in self.fsdp: + raise ValueError("`--fsdp full_shard` is not compatible with `--fsdp shard_grad_op`.") + + if self.gradient_checkpointing and ( + FSDPOption.FULL_SHARD in self.fsdp or FSDPOption.HYBRID_SHARD in self.fsdp + ): + logger.warning( + "When using FSDP full shard, instead of using `gradient_checkpointing` in TrainingArguments, please" + " use `activation_checkpointing` in `fsdp_config`. The former introduces a redundant AllGather" + " operation in backward pass. Reference: https://github.com/huggingface/transformers/issues/30404" + ) + + if self.fsdp_config is None: + self.fsdp_config = {} + + if isinstance(self.fsdp_config, str): + if len(self.fsdp) == 0: + warnings.warn("`--fsdp_config` is useful only when `--fsdp` is specified.") + with open(self.fsdp_config, encoding="utf-8") as f: + self.fsdp_config = json.load(f) + + if self.fsdp_config is not None and isinstance(self.fsdp_config, dict): + for k in list(self.fsdp_config.keys()): + if k.startswith("fsdp_"): + v = self.fsdp_config.pop(k) + self.fsdp_config[k[5:]] = v + + if self.fsdp_min_num_params > 0: + warnings.warn("using `--fsdp_min_num_params` is deprecated. Use fsdp_config instead ", FutureWarning) + + self.fsdp_config["min_num_params"] = max(self.fsdp_config.get("min_num_params", 0), self.fsdp_min_num_params) + + # if fsdp_config["transformer_layer_cls_to_wrap"] is specified as a string, convert it to a list with a single object + if isinstance(self.fsdp_config.get("transformer_layer_cls_to_wrap", None), str): + self.fsdp_config["transformer_layer_cls_to_wrap"] = [self.fsdp_config["transformer_layer_cls_to_wrap"]] + + if self.fsdp_transformer_layer_cls_to_wrap is not None: + warnings.warn( + "using `--fsdp_transformer_layer_cls_to_wrap` is deprecated. Use fsdp_config instead ", FutureWarning + ) + self.fsdp_config["transformer_layer_cls_to_wrap"] = self.fsdp_config.get( + "transformer_layer_cls_to_wrap", [] + ) + [self.fsdp_transformer_layer_cls_to_wrap] + + if len(self.fsdp) == 0 and self.fsdp_config["min_num_params"] > 0: + warnings.warn("`min_num_params` is useful only when `--fsdp` is specified.") + + if len(self.fsdp) == 0 and self.fsdp_config.get("transformer_layer_cls_to_wrap", None) is not None: + warnings.warn("`transformer_layer_cls_to_wrap` is useful only when `--fsdp` is specified.") + + if ( + len(self.fsdp) > 0 + and self.fsdp_config["min_num_params"] > 0 + and self.fsdp_config.get("transformer_layer_cls_to_wrap", None) is not None + ): + raise ValueError("`min_num_params` and `transformer_layer_cls_to_wrap` are mutually exclusive.") + self.fsdp_config["xla"] = self.fsdp_config.get("xla", False) + self.fsdp_config["xla_fsdp_v2"] = self.fsdp_config.get("xla_fsdp_v2", False) + self.fsdp_config["xla_fsdp_grad_ckpt"] = self.fsdp_config.get("xla_fsdp_grad_ckpt", False) + if self.fsdp_config["xla"]: + if len(self.fsdp) > 0: + # store XLA fsdp configuration parameters into a dictionary + # Copy the config to avoid modifying the original config (which may be used for JSON serialization) + self.xla_fsdp_config = self.fsdp_config.get("xla_fsdp_settings", {}).copy() + # apply appropriate string to torch.dtype conversions for parameters + if "compute_dtype" in self.xla_fsdp_config: + self.xla_fsdp_config["compute_dtype"] = getattr(torch, self.xla_fsdp_config["compute_dtype"]) + if "buffer_dtype" in self.xla_fsdp_config: + self.xla_fsdp_config["buffer_dtype"] = getattr(torch, self.xla_fsdp_config["buffer_dtype"]) + else: + warnings.warn("XLA FSDP can be used only when `--fsdp` is specified.") + else: + if self.fsdp_config["xla_fsdp_grad_ckpt"]: + warnings.warn("`--xla_fsdp_grad_ckpt` is useful only when `--xla` is set to true.") + + # accelerate integration for FSDP + if len(self.fsdp) > 0 and not self.fsdp_config["xla"]: + os.environ["ACCELERATE_USE_FSDP"] = "true" + from accelerate.utils.constants import ( + FSDP_AUTO_WRAP_POLICY, + FSDP_SHARDING_STRATEGY, + ) + + prefix = "FSDP_" + for fsdp_option in self.fsdp: + if fsdp_option.upper() in FSDP_SHARDING_STRATEGY: + # set environment variable for FSDP sharding strategy + os.environ[f"{prefix}SHARDING_STRATEGY"] = str( + FSDP_SHARDING_STRATEGY.index(fsdp_option.upper()) + 1 + ) + elif fsdp_option == FSDPOption.OFFLOAD: + os.environ[f"{prefix}OFFLOAD_PARAMS"] = "true" + elif fsdp_option == FSDPOption.AUTO_WRAP: + os.environ[f"{prefix}AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[0] + if self.fsdp_config["min_num_params"] > 0: + os.environ[f"{prefix}MIN_NUM_PARAMS"] = str(self.fsdp_config["min_num_params"]) + os.environ[f"{prefix}AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[1] + elif self.fsdp_config.get("transformer_layer_cls_to_wrap", None) is not None: + os.environ[f"{prefix}TRANSFORMER_CLS_TO_WRAP"] = ",".join( + self.fsdp_config["transformer_layer_cls_to_wrap"] + ) + prefetch_policy = self.fsdp_config.get("backward_prefetch", "NO_PREFETCH") + os.environ[f"{prefix}BACKWARD_PREFETCH"] = prefetch_policy.upper() + os.environ[f"{prefix}FORWARD_PREFETCH"] = str(self.fsdp_config.get("forward_prefetch", "false")).lower() + + sync_module_states = str(self.fsdp_config.get("sync_module_states", "true")).lower() + cpu_ram_efficient_loading = str(self.fsdp_config.get("cpu_ram_efficient_loading", "false")).lower() + + if sync_module_states == "false" and cpu_ram_efficient_loading == "true": + # In this case, all the processes except the main process would have random weights leading + # to unexpected behaviour during training, thus throwing error here to prevent it. + raise ValueError('`sync_module_states` must be `"True"` if `cpu_ram_efficient_loading` is `"True"`') + + os.environ[f"{prefix}SYNC_MODULE_STATES"] = sync_module_states + os.environ[f"{prefix}CPU_RAM_EFFICIENT_LOADING"] = cpu_ram_efficient_loading + + os.environ[f"{prefix}USE_ORIG_PARAMS"] = str(self.fsdp_config.get("use_orig_params", "true")).lower() + + if self.tpu_metrics_debug: + warnings.warn( + "using `--tpu_metrics_debug` is deprecated and will be removed in version 5 of 🤗 Transformers. Use" + " `--debug tpu_metrics_debug` instead", + FutureWarning, + ) + if self.debug is None: + self.debug = " tpu_metrics_debug" + else: + self.debug += " tpu_metrics_debug" + self.tpu_metrics_debug = False + + if isinstance(self.debug, str): + self.debug = [DebugOption(s) for s in self.debug.split()] + elif self.debug is None: + self.debug = [] + + self.deepspeed_plugin = None + if self.deepspeed: + # - must be run very last in arg parsing, since it will use a lot of these settings. + # - must be run before the model is created. + if not is_accelerate_available(): + raise ValueError( + f"--deepspeed requires Accelerate to be installed: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`." + ) + from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig + + # will be used later by the Trainer + # note: leave self.deepspeed unmodified in case a user relies on it not to be modified) + self.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.deepspeed) + self.hf_deepspeed_config.trainer_config_process(self) + + # Accelerate DeepSpeed Plugin + from accelerate.utils import DeepSpeedPlugin + + os.environ["ACCELERATE_USE_DEEPSPEED"] = "true" + self.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.hf_deepspeed_config) + elif strtobool(os.environ.get("ACCELERATE_USE_DEEPSPEED", "false")): + # Accelerate DeepSpeed Plugin + from accelerate.utils import DeepSpeedPlugin + + self.deepspeed_plugin = DeepSpeedPlugin() + mixed_precision = os.environ.get("ACCELERATE_MIXED_PRECISION", "no") + self.deepspeed_plugin.set_mixed_precision(mixed_precision) + self.deepspeed_plugin.set_deepspeed_weakref() + + # Set mixed precision environment variable after DeepSpeed processing + # This ensures DeepSpeed config overrides have been applied to fp16/bf16 settings + if self.half_precision_backend != "apex": + mixed_precision_dtype = os.environ.get("ACCELERATE_MIXED_PRECISION", "no") + if self.fp16: + mixed_precision_dtype = "fp16" + elif self.bf16: + mixed_precision_dtype = "bf16" + os.environ["ACCELERATE_MIXED_PRECISION"] = mixed_precision_dtype + + if self.use_cpu: + self.dataloader_pin_memory = False + + if self.dataloader_num_workers == 0 and self.dataloader_prefetch_factor is not None: + raise ValueError( + "--dataloader_prefetch_factor can only be set when data is loaded in a different process, i.e." + " when --dataloader_num_workers > 1." + ) + + if self.push_to_hub_token is not None: + warnings.warn( + "`--push_to_hub_token` is deprecated and will be removed in version 5 of 🤗 Transformers. Use " + "`--hub_token` instead.", + FutureWarning, + ) + self.hub_token = self.push_to_hub_token + + if self.push_to_hub_model_id is not None: + self.hub_model_id = get_full_repo_name( + self.push_to_hub_model_id, organization=self.push_to_hub_organization, token=self.hub_token + ) + if self.push_to_hub_organization is not None: + warnings.warn( + "`--push_to_hub_model_id` and `--push_to_hub_organization` are deprecated and will be removed in " + "version 5 of 🤗 Transformers. Use `--hub_model_id` instead and pass the full repo name to this " + f"argument (in this case {self.hub_model_id}).", + FutureWarning, + ) + else: + warnings.warn( + "`--push_to_hub_model_id` is deprecated and will be removed in version 5 of 🤗 Transformers. Use " + "`--hub_model_id` instead and pass the full repo name to this argument (in this case " + f"{self.hub_model_id}).", + FutureWarning, + ) + elif self.push_to_hub_organization is not None: + self.hub_model_id = f"{self.push_to_hub_organization}/{Path(self.output_dir).name}" + warnings.warn( + "`--push_to_hub_organization` is deprecated and will be removed in version 5 of 🤗 Transformers. Use " + "`--hub_model_id` instead and pass the full repo name to this argument (in this case " + f"{self.hub_model_id}).", + FutureWarning, + ) + + if self.eval_use_gather_object and not is_accelerate_available("0.30.0"): + raise ValueError( + "--eval_use_gather_object requires Accelerate to be version of `accelerate` > 0.30.0." + "This is not supported and we recommend you to update your version." + ) + + if self.data_seed is not None: + if not is_accelerate_available("1.1.0"): + raise NotImplementedError( + "data_seed requires Accelerate version `accelerate` >= 1.1.0. " + "This is not supported and we recommend you to update your version." + ) + + if self.include_inputs_for_metrics: + logger.warning( + "Using `include_inputs_for_metrics` is deprecated and will be removed in version 5 of 🤗 Transformers. Please use `include_for_metrics` list argument instead." + ) + self.include_for_metrics.append("inputs") + + if self.include_num_input_tokens_seen is True: + self.include_num_input_tokens_seen = "all" + elif self.include_num_input_tokens_seen is False: + self.include_num_input_tokens_seen = "no" + + def __str__(self): + self_as_dict = asdict(self) + + # Remove deprecated arguments. That code should be removed once + # those deprecated arguments are removed from TrainingArguments. (TODO: v5) + del self_as_dict["per_gpu_train_batch_size"] + del self_as_dict["per_gpu_eval_batch_size"] + + self_as_dict = {k: f"<{k.upper()}>" if k.endswith("_token") else v for k, v in self_as_dict.items()} + + attrs_as_str = [f"{k}={v},\n" for k, v in sorted(self_as_dict.items())] + return f"{self.__class__.__name__}(\n{''.join(attrs_as_str)})" + + __repr__ = __str__ + + @property + def train_batch_size(self) -> int: + """ + The actual batch size for training (may differ from `per_gpu_train_batch_size` in distributed training). + """ + if self.per_gpu_train_batch_size: + logger.warning( + "Using deprecated `--per_gpu_train_batch_size` argument which will be removed in a future " + "version. Using `--per_device_train_batch_size` is preferred." + ) + per_device_batch_size = self.per_gpu_train_batch_size or self.per_device_train_batch_size + train_batch_size = per_device_batch_size * max(1, self.n_gpu) + return train_batch_size + + @property + def eval_batch_size(self) -> int: + """ + The actual batch size for evaluation (may differ from `per_gpu_eval_batch_size` in distributed training). + """ + if self.per_gpu_eval_batch_size: + logger.warning( + "Using deprecated `--per_gpu_eval_batch_size` argument which will be removed in a future " + "version. Using `--per_device_eval_batch_size` is preferred." + ) + per_device_batch_size = self.per_gpu_eval_batch_size or self.per_device_eval_batch_size + eval_batch_size = per_device_batch_size * max(1, self.n_gpu) + return eval_batch_size + + @property + def ddp_timeout_delta(self) -> timedelta: + """ + The actual timeout for torch.distributed.init_process_group since it expects a timedelta variable. + """ + return timedelta(seconds=self.ddp_timeout) + + @cached_property + def _setup_devices(self) -> "torch.device": + requires_backends(self, ["torch"]) + logger.info("PyTorch: setting up devices") + if not is_sagemaker_mp_enabled(): + if not is_accelerate_available(): + raise ImportError( + f"Using the `Trainer` with `PyTorch` requires `accelerate>={ACCELERATE_MIN_VERSION}`: " + f"Please run `pip install transformers[torch]` or `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`" + ) + # We delay the init of `PartialState` to the end for clarity + accelerator_state_kwargs: dict[str, Any] = {"enabled": True, "use_configured_state": False} + if isinstance(self.accelerator_config, AcceleratorConfig): + accelerator_state_kwargs["use_configured_state"] = self.accelerator_config.pop( + "use_configured_state", False + ) + if accelerator_state_kwargs["use_configured_state"]: + if PartialState._shared_state == {}: + raise ValueError( + "Passing `'use_configured_state':True` to the AcceleratorConfig requires a pre-configured " + "`AcceleratorState` or `PartialState` to be defined before calling `TrainingArguments`. " + ) + # We rely on `PartialState` to yell if there's issues here (which it will) + self.distributed_state = PartialState(cpu=self.use_cpu) + if self.deepspeed and self.distributed_state.distributed_type != DistributedType.DEEPSPEED: + raise RuntimeError( + "Tried to use an already configured `Accelerator` or `PartialState` that was not initialized for DeepSpeed, " + "but also passed in a `deepspeed` configuration to the `TrainingArguments`. Please set " + "`use_configured_state:False` instead or setup your `Accelerator` or `PartialState` properly." + ) + else: + AcceleratorState._reset_state(reset_partial_state=True) + self.distributed_state = None + if "ACCELERATE_USE_IPEX" not in os.environ: + os.environ["ACCELERATE_USE_IPEX"] = "false" + + self._n_gpu = 1 + if self.use_cpu or strtobool(os.environ.get("ACCELERATE_USE_CPU", "False")): + accelerator_state_kwargs["cpu"] = True + accelerator_state_kwargs["backend"] = self.ddp_backend + self._n_gpu = 0 + elif is_sagemaker_mp_enabled(): + accelerator_state_kwargs["enabled"] = False + local_rank = smp.local_rank() + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + elif is_sagemaker_dp_enabled(): + accelerator_state_kwargs["_use_sagemaker_dp"] = True + elif self.deepspeed: + accelerator_state_kwargs["use_deepspeed"] = True + accelerator_state_kwargs["timeout"] = timedelta(seconds=self.ddp_timeout) + else: + accelerator_state_kwargs["backend"] = self.ddp_backend + accelerator_state_kwargs["timeout"] = timedelta(seconds=self.ddp_timeout) + + # Now we pop everything + if accelerator_state_kwargs.pop("enabled", False) and not accelerator_state_kwargs.pop( + "use_configured_state", False + ): + # We need to patch this env var when enabling to detect deepspeed + use_deepspeed = accelerator_state_kwargs.pop("use_deepspeed", False) + if use_deepspeed: + os.environ["ACCELERATE_USE_DEEPSPEED"] = "true" + self.distributed_state = PartialState(**accelerator_state_kwargs) + if use_deepspeed: + del os.environ["ACCELERATE_USE_DEEPSPEED"] + if not is_sagemaker_mp_enabled(): + device = self.distributed_state.device + self.local_rank = self.distributed_state.local_process_index + if dist.is_available() and dist.is_initialized() and self.parallel_mode != ParallelMode.DISTRIBUTED: + logger.warning( + "torch.distributed process group is initialized, but parallel_mode != ParallelMode.DISTRIBUTED. " + "In order to use Torch DDP, launch your script with `python -m torch.distributed.launch" + ) + if is_torch_xla_available(): + device = self.distributed_state.device + self._n_gpu = 0 + elif is_sagemaker_dp_enabled() or is_sagemaker_mp_enabled(): + # Already set _n_gpu + pass + elif self.distributed_state.distributed_type == DistributedType.NO: + if self.use_mps_device: + warnings.warn( + "`use_mps_device` is deprecated and will be removed in version 5.0 of 🤗 Transformers. " + "`mps` device will be used by default if available similar to the way `cuda` device is used." + "Therefore, no action from user is required. " + ) + if device.type != "mps": + raise ValueError( + "Either you do not have an MPS-enabled device on this machine or MacOS version is not 12.3+ " + "or current PyTorch install was not built with MPS enabled." + ) + if self.use_cpu: + device = torch.device("cpu") + elif is_torch_mps_available(): + device = torch.device("mps") + elif is_torch_xpu_available(): + if not is_ipex_available() and not is_accelerate_available("0.32.0.dev"): + raise ImportError("Using the XPU PyTorch backend requires `accelerate>=0.32.0.dev`") + device = torch.device("xpu:0") + torch.xpu.set_device(device) + elif is_torch_mlu_available(): + device = torch.device("mlu:0") + torch.mlu.set_device(device) + elif is_torch_musa_available(): + device = torch.device("musa:0") + torch.musa.set_device(device) + elif is_torch_npu_available(): + device = torch.device("npu:0") + torch.npu.set_device(device) + elif is_torch_hpu_available(): + device = torch.device("hpu:0") + torch.hpu.set_device(device) + else: + # if n_gpu is > 1 we'll use nn.DataParallel. + # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0` + # Explicitly set CUDA to the first (index 0) CUDA device, otherwise `set_device` will + # trigger an error that a device index is missing. Index 0 takes into account the + # GPUs available in the environment, so `CUDA_VISIBLE_DEVICES=1,2` with `cuda:0` + # will use the first GPU in that env, i.e. GPU#1 + device = torch.device( + "cuda:0" if torch.cuda.is_available() else os.environ.get("ACCELERATE_TORCH_DEVICE", "cpu") + ) + # Sometimes the line in the postinit has not been run before we end up here, so just checking we're not at + # the default value. + self._n_gpu = torch.cuda.device_count() + if device.type == "cuda": + torch.cuda.set_device(device) + return device + + @property + def device(self) -> "torch.device": + """ + The device used by this process. + """ + requires_backends(self, ["torch"]) + return self._setup_devices + + @property + def n_gpu(self): + """ + The number of GPUs used by this process. + + Note: + This will only be greater than one when you have multiple GPUs available but are not using distributed + training. For distributed training, it will always be 1. + """ + requires_backends(self, ["torch"]) + # Make sure `self._n_gpu` is properly setup. + if not hasattr(self, "_n_gpu"): + _ = self._setup_devices + return self._n_gpu + + @property + def parallel_mode(self): + """ + The current mode used for parallelism if multiple GPUs/TPU cores are available. One of: + + - `ParallelMode.NOT_PARALLEL`: no parallelism (CPU or one GPU). + - `ParallelMode.NOT_DISTRIBUTED`: several GPUs in one single process (uses `torch.nn.DataParallel`). + - `ParallelMode.DISTRIBUTED`: several GPUs, each having its own process (uses + `torch.nn.DistributedDataParallel`). + - `ParallelMode.TPU`: several TPU cores. + """ + requires_backends(self, ["torch"]) + if is_torch_xla_available(): + return ParallelMode.TPU + elif is_sagemaker_mp_enabled(): + return ParallelMode.SAGEMAKER_MODEL_PARALLEL + elif is_sagemaker_dp_enabled(): + return ParallelMode.SAGEMAKER_DATA_PARALLEL + elif ( + self.distributed_state is not None and self.distributed_state.distributed_type != DistributedType.NO + ) or (self.distributed_state is None and self.local_rank != -1): + return ParallelMode.DISTRIBUTED + elif self.n_gpu > 1: + return ParallelMode.NOT_DISTRIBUTED + else: + return ParallelMode.NOT_PARALLEL + + @property + def world_size(self): + """ + The number of processes used in parallel. + """ + requires_backends(self, ["torch"]) + if self.distributed_state is not None: + return self.distributed_state.num_processes + elif is_sagemaker_mp_enabled(): + return smp.dp_size() if not smp.state.cfg.prescaled_batch else smp.rdp_size() + return 1 + + @property + def process_index(self): + """ + The index of the current process used. + """ + requires_backends(self, ["torch"]) + if self.distributed_state is not None: + return self.distributed_state.process_index + elif is_sagemaker_mp_enabled(): + return smp.dp_rank() if not smp.state.cfg.prescaled_batch else smp.rdp_rank() + return 0 + + @property + def local_process_index(self): + """ + The index of the local process used. + """ + requires_backends(self, ["torch"]) + + if self.distributed_state is not None: + return self.distributed_state.local_process_index + elif is_sagemaker_mp_enabled(): + return smp.local_rank() + return 0 + + @property + def should_log(self): + """ + Whether or not the current process should produce log. + """ + if self.log_on_each_node: + return self.local_process_index == 0 + else: + if is_sagemaker_mp_enabled(): + return smp.rank() == 0 + else: + return self.process_index == 0 + + @property + def should_save(self): + """ + Whether or not the current process should write to disk, e.g., to save models and checkpoints. + """ + if self.save_on_each_node: + return self.local_process_index == 0 + else: + if is_sagemaker_mp_enabled(): + return smp.rank() == 0 + else: + return self.process_index == 0 + + def get_process_log_level(self): + """ + Returns the log level to be used depending on whether this process is the main process of node 0, main process + of node non-0, or a non-main process. + + For the main process the log level defaults to the logging level set (`logging.WARNING` if you didn't do + anything) unless overridden by `log_level` argument. + + For the replica processes the log level defaults to `logging.WARNING` unless overridden by `log_level_replica` + argument. + + The choice between the main and replica process settings is made according to the return value of `should_log`. + """ + + # convert to int + log_level = trainer_log_levels[self.log_level] + log_level_replica = trainer_log_levels[self.log_level_replica] + + log_level_main_node = logging.get_verbosity() if log_level == -1 else log_level + log_level_replica_node = logging.get_verbosity() if log_level_replica == -1 else log_level_replica + return log_level_main_node if self.should_log else log_level_replica_node + + @property + def place_model_on_device(self): + """ + Can be subclassed and overridden for some specific integrations. + """ + return not is_sagemaker_mp_enabled() + + @property + def _no_sync_in_gradient_accumulation(self): + """ + Whether or not to use no_sync for the gradients when doing gradient accumulation. + """ + return not ( + self.deepspeed or is_sagemaker_dp_enabled() or is_sagemaker_mp_enabled() or is_torch_neuroncore_available() + ) + + @contextlib.contextmanager + def main_process_first(self, local=True, desc="work"): + """ + A context manager for torch distributed environment where on needs to do something on the main process, while + blocking replicas, and when it's finished releasing the replicas. + + One such use is for `datasets`'s `map` feature which to be efficient should be run once on the main process, + which upon completion saves a cached version of results and which then automatically gets loaded by the + replicas. + + Args: + local (`bool`, *optional*, defaults to `True`): + if `True` first means process of rank 0 of each node if `False` first means process of rank 0 of node + rank 0 In multi-node environment with a shared filesystem you most likely will want to use + `local=False` so that only the main process of the first node will do the processing. If however, the + filesystem is not shared, then the main process of each node will need to do the processing, which is + the default behavior. + desc (`str`, *optional*, defaults to `"work"`): + a work description to be used in debug logs + + """ + if is_torch_available() and self.world_size > 1: + main_process_desc = "main local process" if local else "main process" + if self.distributed_state is not None: + is_main_process = ( + self.distributed_state.is_local_main_process if local else self.distributed_state.is_main_process + ) + elif is_sagemaker_mp_enabled(): + is_main_process = smp.rank() == 0 + + try: + if not is_main_process: + # tell all replicas to wait + logger.debug(f"{self.process_index}: waiting for the {main_process_desc} to perform {desc}") + + if is_torch_xla_available(): + xm.rendezvous(desc) + else: + dist.barrier() + yield + finally: + if is_main_process: + # the wait is over + logger.debug(f"{self.process_index}: {main_process_desc} completed {desc}, releasing all replicas") + if is_torch_xla_available(): + xm.rendezvous(desc) + else: + dist.barrier() + else: + yield + + def get_warmup_steps(self, num_training_steps: int): + """ + Get number of steps used for a linear warmup. + """ + warmup_steps = ( + self.warmup_steps if self.warmup_steps > 0 else math.ceil(num_training_steps * self.warmup_ratio) + ) + return warmup_steps + + def _dict_dtype_to_str(self, d: dict[str, Any]) -> None: + """ + Checks whether the passed dictionary and its nested dicts have a *dtype* key and if it's not None, + converts torch.dtype to a string of just the type. For example, `torch.float32` get converted into *"float32"* + string, which can then be stored in the json format. + """ + if d.get("dtype") is not None and not isinstance(d["dtype"], str): + d["dtype"] = str(d["dtype"]).split(".")[1] + for value in d.values(): + if isinstance(value, dict): + self._dict_dtype_to_str(value) + + def to_dict(self): + """ + Serializes this instance while replace `Enum` by their values (for JSON serialization support). It obfuscates + the token values by removing their value. + """ + # filter out fields that are defined as field(init=False) + d = {field.name: getattr(self, field.name) for field in fields(self) if field.init} + + for k, v in d.items(): + if isinstance(v, Enum): + d[k] = v.value + if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum): + d[k] = [x.value for x in v] + if k.endswith("_token"): + d[k] = f"<{k.upper()}>" + # Handle the accelerator_config if passed + if is_accelerate_available() and isinstance(v, AcceleratorConfig): + d[k] = v.to_dict() + # Handle the quantization_config if passed + if k == "model_init_kwargs" and isinstance(v, dict) and "quantization_config" in v: + quantization_config = v.get("quantization_config") + if quantization_config and not isinstance(quantization_config, dict): + d[k]["quantization_config"] = quantization_config.to_dict() + if k == "parallelism_config" and v is not None: + d[k] = v.to_json() + + self._dict_dtype_to_str(d) + + return d + + def to_json_string(self): + """ + Serializes this instance to a JSON string. + """ + return json.dumps(self.to_dict(), indent=2) + + def to_sanitized_dict(self) -> dict[str, Any]: + """ + Sanitized serialization to use with TensorBoard's hparams + """ + d = self.to_dict() + d = {**d, **{"train_batch_size": self.train_batch_size, "eval_batch_size": self.eval_batch_size}} + + valid_types = [bool, int, float, str] + if is_torch_available(): + valid_types.append(torch.Tensor) + + return {k: v if type(v) in valid_types else str(v) for k, v in d.items()} + + # The following methods are there to simplify the instantiation of `TrainingArguments` + def set_training( + self, + learning_rate: float = 5e-5, + batch_size: int = 8, + weight_decay: float = 0, + num_epochs: float = 3, + max_steps: int = -1, + gradient_accumulation_steps: int = 1, + seed: int = 42, + gradient_checkpointing: bool = False, + ): + """ + A method that regroups all basic arguments linked to the training. + + + + Calling this method will automatically set `self.do_train` to `True`. + + + + Args: + learning_rate (`float`, *optional*, defaults to 5e-5): + The initial learning rate for the optimizer. + batch_size (`int` *optional*, defaults to 8): + The batch size per device (GPU/TPU core/CPU...) used for training. + weight_decay (`float`, *optional*, defaults to 0): + The weight decay to apply (if not zero) to all layers except all bias and LayerNorm weights in the + optimizer. + num_train_epochs(`float`, *optional*, defaults to 3.0): + Total number of training epochs to perform (if not an integer, will perform the decimal part percents + of the last epoch before stopping training). + max_steps (`int`, *optional*, defaults to -1): + If set to a positive number, the total number of training steps to perform. Overrides `num_train_epochs`. + For a finite dataset, training is reiterated through the dataset (if all data is exhausted) until + `max_steps` is reached. + gradient_accumulation_steps (`int`, *optional*, defaults to 1): + Number of updates steps to accumulate the gradients for, before performing a backward/update pass. + + + + When using gradient accumulation, one step is counted as one step with backward pass. Therefore, + logging, evaluation, save will be conducted every `gradient_accumulation_steps * xxx_step` training + examples. + + + + seed (`int`, *optional*, defaults to 42): + Random seed that will be set at the beginning of training. To ensure reproducibility across runs, use + the [`~Trainer.model_init`] function to instantiate the model if it has some randomly initialized + parameters. + gradient_checkpointing (`bool`, *optional*, defaults to `False`): + If True, use gradient checkpointing to save memory at the expense of slower backward pass. + + Example: + + ```py + >>> from transformers import TrainingArguments + + >>> args = TrainingArguments("working_dir") + >>> args = args.set_training(learning_rate=1e-4, batch_size=32) + >>> args.learning_rate + 1e-4 + ``` + """ + self.do_train = True + self.learning_rate = learning_rate + self.per_device_train_batch_size = batch_size + self.weight_decay = weight_decay + self.num_train_epochs = num_epochs + self.max_steps = max_steps + self.gradient_accumulation_steps = gradient_accumulation_steps + self.seed = seed + self.gradient_checkpointing = gradient_checkpointing + return self + + def set_evaluate( + self, + strategy: Union[str, IntervalStrategy] = "no", + steps: int = 500, + batch_size: int = 8, + accumulation_steps: Optional[int] = None, + delay: Optional[float] = None, + loss_only: bool = False, + jit_mode: bool = False, + ): + """ + A method that regroups all arguments linked to evaluation. + + Args: + strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"no"`): + The evaluation strategy to adopt during training. Possible values are: + + - `"no"`: No evaluation is done during training. + - `"steps"`: Evaluation is done (and logged) every `steps`. + - `"epoch"`: Evaluation is done at the end of each epoch. + + Setting a `strategy` different from `"no"` will set `self.do_eval` to `True`. + steps (`int`, *optional*, defaults to 500): + Number of update steps between two evaluations if `strategy="steps"`. + batch_size (`int` *optional*, defaults to 8): + The batch size per device (GPU/TPU core/CPU...) used for evaluation. + accumulation_steps (`int`, *optional*): + Number of predictions steps to accumulate the output tensors for, before moving the results to the CPU. + If left unset, the whole predictions are accumulated on GPU/TPU before being moved to the CPU (faster + but requires more memory). + delay (`float`, *optional*): + Number of epochs or steps to wait for before the first evaluation can be performed, depending on the + eval_strategy. + loss_only (`bool`, *optional*, defaults to `False`): + Ignores all outputs except the loss. + jit_mode (`bool`, *optional*): + Whether or not to use PyTorch jit trace for inference. + + Example: + + ```py + >>> from transformers import TrainingArguments + + >>> args = TrainingArguments("working_dir") + >>> args = args.set_evaluate(strategy="steps", steps=100) + >>> args.eval_steps + 100 + ``` + """ + self.eval_strategy = IntervalStrategy(strategy) + if self.eval_strategy == IntervalStrategy.STEPS and steps == 0: + raise ValueError("Setting `strategy` as 'steps' requires a positive value for `steps`.") + self.do_eval = self.eval_strategy != IntervalStrategy.NO + self.eval_steps = steps + self.per_device_eval_batch_size = batch_size + self.eval_accumulation_steps = accumulation_steps + self.eval_delay = delay + self.prediction_loss_only = loss_only + self.jit_mode_eval = jit_mode + return self + + def set_testing( + self, + batch_size: int = 8, + loss_only: bool = False, + jit_mode: bool = False, + ): + """ + A method that regroups all basic arguments linked to testing on a held-out dataset. + + + + Calling this method will automatically set `self.do_predict` to `True`. + + + + Args: + batch_size (`int` *optional*, defaults to 8): + The batch size per device (GPU/TPU core/CPU...) used for testing. + loss_only (`bool`, *optional*, defaults to `False`): + Ignores all outputs except the loss. + jit_mode (`bool`, *optional*): + Whether or not to use PyTorch jit trace for inference. + + Example: + + ```py + >>> from transformers import TrainingArguments + + >>> args = TrainingArguments("working_dir") + >>> args = args.set_testing(batch_size=32) + >>> args.per_device_eval_batch_size + 32 + ``` + """ + self.do_predict = True + self.per_device_eval_batch_size = batch_size + self.prediction_loss_only = loss_only + self.jit_mode_eval = jit_mode + return self + + def set_save( + self, + strategy: Union[str, IntervalStrategy] = "steps", + steps: int = 500, + total_limit: Optional[int] = None, + on_each_node: bool = False, + ): + """ + A method that regroups all arguments linked to checkpoint saving. + + Args: + strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"steps"`): + The checkpoint save strategy to adopt during training. Possible values are: + + - `"no"`: No save is done during training. + - `"epoch"`: Save is done at the end of each epoch. + - `"steps"`: Save is done every `save_steps`. + + steps (`int`, *optional*, defaults to 500): + Number of updates steps before two checkpoint saves if `strategy="steps"`. + total_limit (`int`, *optional*): + If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in + `output_dir`. + on_each_node (`bool`, *optional*, defaults to `False`): + When doing multi-node distributed training, whether to save models and checkpoints on each node, or + only on the main one. + + This should not be activated when the different nodes use the same storage as the files will be saved + with the same names for each node. + + Example: + + ```py + >>> from transformers import TrainingArguments + + >>> args = TrainingArguments("working_dir") + >>> args = args.set_save(strategy="steps", steps=100) + >>> args.save_steps + 100 + ``` + """ + self.save_strategy = SaveStrategy(strategy) + if self.save_strategy == SaveStrategy.STEPS and steps == 0: + raise ValueError("Setting `strategy` as 'steps' requires a positive value for `steps`.") + self.save_steps = steps + self.save_total_limit = total_limit + self.save_on_each_node = on_each_node + return self + + def set_logging( + self, + strategy: Union[str, IntervalStrategy] = "steps", + steps: int = 500, + report_to: Union[str, list[str]] = "none", + level: str = "passive", + first_step: bool = False, + nan_inf_filter: bool = False, + on_each_node: bool = False, + replica_level: str = "passive", + ): + """ + A method that regroups all arguments linked to logging. + + Args: + strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"steps"`): + The logging strategy to adopt during training. Possible values are: + + - `"no"`: No logging is done during training. + - `"epoch"`: Logging is done at the end of each epoch. + - `"steps"`: Logging is done every `logging_steps`. + + steps (`int`, *optional*, defaults to 500): + Number of update steps between two logs if `strategy="steps"`. + level (`str`, *optional*, defaults to `"passive"`): + Logger log level to use on the main process. Possible choices are the log levels as strings: `"debug"`, + `"info"`, `"warning"`, `"error"` and `"critical"`, plus a `"passive"` level which doesn't set anything + and lets the application set the level. + report_to (`str` or `list[str]`, *optional*, defaults to `"all"`): + The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`, + `"clearml"`, `"codecarbon"`, `"comet_ml"`, `"dagshub"`, `"dvclive"`, `"flyte"`, `"mlflow"`, + `"neptune"`, `"swanlab"`, `"tensorboard"`, `"trackio"` and `"wandb"`. Use `"all"` to report to all + integrations installed, `"none"` for no integrations. + first_step (`bool`, *optional*, defaults to `False`): + Whether to log and evaluate the first `global_step` or not. + nan_inf_filter (`bool`, *optional*, defaults to `True`): + Whether to filter `nan` and `inf` losses for logging. If set to `True` the loss of every step that is + `nan` or `inf` is filtered and the average loss of the current logging window is taken instead. + + + + `nan_inf_filter` only influences the logging of loss values, it does not change the behavior the + gradient is computed or applied to the model. + + + + on_each_node (`bool`, *optional*, defaults to `True`): + In multinode distributed training, whether to log using `log_level` once per node, or only on the main + node. + replica_level (`str`, *optional*, defaults to `"passive"`): + Logger log level to use on replicas. Same choices as `log_level` + + Example: + + ```py + >>> from transformers import TrainingArguments + + >>> args = TrainingArguments("working_dir") + >>> args = args.set_logging(strategy="steps", steps=100) + >>> args.logging_steps + 100 + ``` + """ + self.logging_strategy = IntervalStrategy(strategy) + if self.logging_strategy == IntervalStrategy.STEPS and steps == 0: + raise ValueError("Setting `strategy` as 'steps' requires a positive value for `steps`.") + self.logging_steps = steps + self.report_to = report_to + self.log_level = level + self.logging_first_step = first_step + self.logging_nan_inf_filter = nan_inf_filter + self.log_on_each_node = on_each_node + self.log_level_replica = replica_level + return self + + def set_push_to_hub( + self, + model_id: str, + strategy: Union[str, HubStrategy] = "every_save", + token: Optional[str] = None, + private_repo: Optional[bool] = None, + always_push: bool = False, + revision: Optional[str] = None, + ): + """ + A method that regroups all arguments linked to synchronizing checkpoints with the Hub. + + + + Calling this method will set `self.push_to_hub` to `True`, which means the `output_dir` will begin a git + directory synced with the repo (determined by `model_id`) and the content will be pushed each time a save is + triggered (depending on your `self.save_strategy`). Calling [`~Trainer.save_model`] will also trigger a push. + + + + Args: + model_id (`str`): + The name of the repository to keep in sync with the local *output_dir*. It can be a simple model ID in + which case the model will be pushed in your namespace. Otherwise it should be the whole repository + name, for instance `"user_name/model"`, which allows you to push to an organization you are a member of + with `"organization_name/model"`. + strategy (`str` or [`~trainer_utils.HubStrategy`], *optional*, defaults to `"every_save"`): + Defines the scope of what is pushed to the Hub and when. Possible values are: + + - `"end"`: push the model, its configuration, the processing_class e.g. tokenizer (if passed along to the [`Trainer`]) and a + draft of a model card when the [`~Trainer.save_model`] method is called. + - `"every_save"`: push the model, its configuration, the processing_class e.g. tokenizer (if passed along to the [`Trainer`]) + and + a draft of a model card each time there is a model save. The pushes are asynchronous to not block + training, and in case the save are very frequent, a new push is only attempted if the previous one is + finished. A last push is made with the final model at the end of training. + - `"checkpoint"`: like `"every_save"` but the latest checkpoint is also pushed in a subfolder named + last-checkpoint, allowing you to resume training easily with + `trainer.train(resume_from_checkpoint="last-checkpoint")`. + - `"all_checkpoints"`: like `"checkpoint"` but all checkpoints are pushed like they appear in the + output + folder (so you will get one checkpoint folder per folder in your final repository) + + token (`str`, *optional*): + The token to use to push the model to the Hub. Will default to the token in the cache folder obtained + with `hf auth login`. + private_repo (`bool`, *optional*, defaults to `False`): + Whether to make the repo private. If `None` (default), the repo will be public unless the organization's default is private. This value is ignored if the repo already exists. + always_push (`bool`, *optional*, defaults to `False`): + Unless this is `True`, the `Trainer` will skip pushing a checkpoint when the previous push is not + finished. + revision (`str`, *optional*): + The revision to use when pushing to the Hub. Can be a branch name, a tag, or a commit hash. + + Example: + + ```py + >>> from transformers import TrainingArguments + + >>> args = TrainingArguments("working_dir") + >>> args = args.set_push_to_hub("me/awesome-model") + >>> args.hub_model_id + 'me/awesome-model' + ``` + """ + self.push_to_hub = True + self.hub_model_id = model_id + self.hub_strategy = HubStrategy(strategy) + self.hub_token = token + self.hub_private_repo = private_repo + self.hub_always_push = always_push + self.hub_revision = revision + return self + + def set_optimizer( + self, + name: Union[str, OptimizerNames] = "adamw_torch", + learning_rate: float = 5e-5, + weight_decay: float = 0, + beta1: float = 0.9, + beta2: float = 0.999, + epsilon: float = 1e-8, + args: Optional[str] = None, + ): + """ + A method that regroups all arguments linked to the optimizer and its hyperparameters. + + Args: + name (`str` or [`training_args.OptimizerNames`], *optional*, defaults to `"adamw_torch"`): + The optimizer to use: `"adamw_torch"`, `"adamw_torch_fused"`, `"adamw_apex_fused"`, + `"adamw_anyprecision"` or `"adafactor"`. + learning_rate (`float`, *optional*, defaults to 5e-5): + The initial learning rate. + weight_decay (`float`, *optional*, defaults to 0): + The weight decay to apply (if not zero) to all layers except all bias and LayerNorm weights. + beta1 (`float`, *optional*, defaults to 0.9): + The beta1 hyperparameter for the adam optimizer or its variants. + beta2 (`float`, *optional*, defaults to 0.999): + The beta2 hyperparameter for the adam optimizer or its variants. + epsilon (`float`, *optional*, defaults to 1e-8): + The epsilon hyperparameter for the adam optimizer or its variants. + args (`str`, *optional*): + Optional arguments that are supplied to AnyPrecisionAdamW (only useful when + `optim="adamw_anyprecision"`). + + Example: + + ```py + >>> from transformers import TrainingArguments + + >>> args = TrainingArguments("working_dir") + >>> args = args.set_optimizer(name="adamw_torch", beta1=0.8) + >>> args.optim + 'adamw_torch' + ``` + """ + self.optim = OptimizerNames(name) + self.learning_rate = learning_rate + self.weight_decay = weight_decay + self.adam_beta1 = beta1 + self.adam_beta2 = beta2 + self.adam_epsilon = epsilon + self.optim_args = args + return self + + def set_lr_scheduler( + self, + name: Union[str, SchedulerType] = "linear", + num_epochs: float = 3.0, + max_steps: int = -1, + warmup_ratio: float = 0, + warmup_steps: int = 0, + ): + """ + A method that regroups all arguments linked to the learning rate scheduler and its hyperparameters. + + Args: + name (`str` or [`SchedulerType`], *optional*, defaults to `"linear"`): + The scheduler type to use. See the documentation of [`SchedulerType`] for all possible values. + num_epochs(`float`, *optional*, defaults to 3.0): + Total number of training epochs to perform (if not an integer, will perform the decimal part percents + of the last epoch before stopping training). + max_steps (`int`, *optional*, defaults to -1): + If set to a positive number, the total number of training steps to perform. Overrides `num_train_epochs`. + For a finite dataset, training is reiterated through the dataset (if all data is exhausted) until + `max_steps` is reached. + warmup_ratio (`float`, *optional*, defaults to 0.0): + Ratio of total training steps used for a linear warmup from 0 to `learning_rate`. + warmup_steps (`int`, *optional*, defaults to 0): + Number of steps used for a linear warmup from 0 to `learning_rate`. Overrides any effect of + `warmup_ratio`. + + Example: + + ```py + >>> from transformers import TrainingArguments + + >>> args = TrainingArguments("working_dir") + >>> args = args.set_lr_scheduler(name="cosine", warmup_ratio=0.05) + >>> args.warmup_ratio + 0.05 + ``` + """ + self.lr_scheduler_type = SchedulerType(name) + self.num_train_epochs = num_epochs + self.max_steps = max_steps + self.warmup_ratio = warmup_ratio + self.warmup_steps = warmup_steps + return self + + def set_dataloader( + self, + train_batch_size: int = 8, + eval_batch_size: int = 8, + drop_last: bool = False, + num_workers: int = 0, + pin_memory: bool = True, + persistent_workers: bool = False, + prefetch_factor: Optional[int] = None, + auto_find_batch_size: bool = False, + ignore_data_skip: bool = False, + sampler_seed: Optional[int] = None, + ): + """ + A method that regroups all arguments linked to the dataloaders creation. + + Args: + drop_last (`bool`, *optional*, defaults to `False`): + Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch + size) or not. + num_workers (`int`, *optional*, defaults to 0): + Number of subprocesses to use for data loading (PyTorch only). 0 means that the data will be loaded in + the main process. + pin_memory (`bool`, *optional*, defaults to `True`): + Whether you want to pin memory in data loaders or not. Will default to `True`. + persistent_workers (`bool`, *optional*, defaults to `False`): + If True, the data loader will not shut down the worker processes after a dataset has been consumed + once. This allows to maintain the workers Dataset instances alive. Can potentially speed up training, + but will increase RAM usage. Will default to `False`. + prefetch_factor (`int`, *optional*): + Number of batches loaded in advance by each worker. + 2 means there will be a total of 2 * num_workers batches prefetched across all workers. + auto_find_batch_size (`bool`, *optional*, defaults to `False`) + Whether to find a batch size that will fit into memory automatically through exponential decay, + avoiding CUDA Out-of-Memory errors. Requires accelerate to be installed (`pip install accelerate`) + ignore_data_skip (`bool`, *optional*, defaults to `False`): + When resuming training, whether or not to skip the epochs and batches to get the data loading at the + same stage as in the previous training. If set to `True`, the training will begin faster (as that + skipping step can take a long time) but will not yield the same results as the interrupted training + would have. + sampler_seed (`int`, *optional*): + Random seed to be used with data samplers. If not set, random generators for data sampling will use the + same seed as `self.seed`. This can be used to ensure reproducibility of data sampling, independent of + the model seed. + + Example: + + ```py + >>> from transformers import TrainingArguments + + >>> args = TrainingArguments("working_dir") + >>> args = args.set_dataloader(train_batch_size=16, eval_batch_size=64) + >>> args.per_device_train_batch_size + 16 + ``` + """ + self.per_device_train_batch_size = train_batch_size + self.per_device_eval_batch_size = eval_batch_size + self.dataloader_drop_last = drop_last + self.dataloader_num_workers = num_workers + self.dataloader_pin_memory = pin_memory + self.dataloader_persistent_workers = persistent_workers + self.dataloader_prefetch_factor = prefetch_factor + self.auto_find_batch_size = auto_find_batch_size + self.ignore_data_skip = ignore_data_skip + self.data_seed = sampler_seed + return self + + +class ParallelMode(Enum): + NOT_PARALLEL = "not_parallel" + NOT_DISTRIBUTED = "not_distributed" + DISTRIBUTED = "distributed" + SAGEMAKER_MODEL_PARALLEL = "sagemaker_model_parallel" + SAGEMAKER_DATA_PARALLEL = "sagemaker_data_parallel" + TPU = "tpu" diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/training_args_seq2seq.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/training_args_seq2seq.py new file mode 100644 index 0000000000000000000000000000000000000000..5342b7add3932c542e35247e52920d8fc91ed325 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/training_args_seq2seq.py @@ -0,0 +1,90 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from dataclasses import dataclass, field +from pathlib import Path +from typing import Optional, Union + +from .generation.configuration_utils import GenerationConfig +from .training_args import TrainingArguments +from .utils import add_start_docstrings + + +logger = logging.getLogger(__name__) + + +@dataclass +@add_start_docstrings(TrainingArguments.__doc__) +class Seq2SeqTrainingArguments(TrainingArguments): + """ + Args: + predict_with_generate (`bool`, *optional*, defaults to `False`): + Whether to use generate to calculate generative metrics (ROUGE, BLEU). + generation_max_length (`int`, *optional*): + The `max_length` to use on each evaluation loop when `predict_with_generate=True`. Will default to the + `max_length` value of the model configuration. + generation_num_beams (`int`, *optional*): + The `num_beams` to use on each evaluation loop when `predict_with_generate=True`. Will default to the + `num_beams` value of the model configuration. + generation_config (`str` or `Path` or [`~generation.GenerationConfig`], *optional*): + Allows to load a [`~generation.GenerationConfig`] from the `from_pretrained` method. This can be either: + + - a string, the *model id* of a pretrained model configuration hosted inside a model repo on + huggingface.co. + - a path to a *directory* containing a configuration file saved using the + [`~GenerationConfig.save_pretrained`] method, e.g., `./my_model_directory/`. + - a [`~generation.GenerationConfig`] object. + """ + + sortish_sampler: bool = field(default=False, metadata={"help": "Whether to use SortishSampler or not."}) + predict_with_generate: bool = field( + default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."} + ) + generation_max_length: Optional[int] = field( + default=None, + metadata={ + "help": ( + "The `max_length` to use on each evaluation loop when `predict_with_generate=True`. Will default " + "to the `max_length` value of the model configuration." + ) + }, + ) + generation_num_beams: Optional[int] = field( + default=None, + metadata={ + "help": ( + "The `num_beams` to use on each evaluation loop when `predict_with_generate=True`. Will default " + "to the `num_beams` value of the model configuration." + ) + }, + ) + generation_config: Optional[Union[str, Path, GenerationConfig]] = field( + default=None, + metadata={ + "help": "Model id, file path or url pointing to a GenerationConfig json file, to use during prediction." + }, + ) + + def to_dict(self): + """ + Serializes this instance while replace `Enum` by their values and `GenerationConfig` by dictionaries (for JSON + serialization support). It obfuscates the token values by removing their value. + """ + # filter out fields that are defined as field(init=False) + d = super().to_dict() + for k, v in d.items(): + if isinstance(v, GenerationConfig): + d[k] = v.to_dict() + return d diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/training_args_tf.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/training_args_tf.py new file mode 100644 index 0000000000000000000000000000000000000000..24763dabf9167896f69cd939ccec99f37d5fc0de --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/training_args_tf.py @@ -0,0 +1,300 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from dataclasses import dataclass, field +from functools import cached_property +from typing import Optional + +from .training_args import TrainingArguments +from .utils import is_tf_available, logging, requires_backends + + +logger = logging.get_logger(__name__) + +if is_tf_available(): + import tensorflow as tf + + from .modeling_tf_utils import keras + + +@dataclass +class TFTrainingArguments(TrainingArguments): + """ + TrainingArguments is the subset of the arguments we use in our example scripts **which relate to the training loop + itself**. + + Using [`HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + output_dir (`str`): + The output directory where the model predictions and checkpoints will be written. + overwrite_output_dir (`bool`, *optional*, defaults to `False`): + If `True`, overwrite the content of the output directory. Use this to continue training if `output_dir` + points to a checkpoint directory. + do_train (`bool`, *optional*, defaults to `False`): + Whether to run training or not. This argument is not directly used by [`Trainer`], it's intended to be used + by your training/evaluation scripts instead. See the [example + scripts](https://github.com/huggingface/transformers/tree/main/examples) for more details. + do_eval (`bool`, *optional*): + Whether to run evaluation on the validation set or not. Will be set to `True` if `eval_strategy` is + different from `"no"`. This argument is not directly used by [`Trainer`], it's intended to be used by your + training/evaluation scripts instead. See the [example + scripts](https://github.com/huggingface/transformers/tree/main/examples) for more details. + do_predict (`bool`, *optional*, defaults to `False`): + Whether to run predictions on the test set or not. This argument is not directly used by [`Trainer`], it's + intended to be used by your training/evaluation scripts instead. See the [example + scripts](https://github.com/huggingface/transformers/tree/main/examples) for more details. + eval_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"no"`): + The evaluation strategy to adopt during training. Possible values are: + + - `"no"`: No evaluation is done during training. + - `"steps"`: Evaluation is done (and logged) every `eval_steps`. + - `"epoch"`: Evaluation is done at the end of each epoch. + + per_device_train_batch_size (`int`, *optional*, defaults to 8): + The batch size per GPU/TPU core/CPU for training. + per_device_eval_batch_size (`int`, *optional*, defaults to 8): + The batch size per GPU/TPU core/CPU for evaluation. + gradient_accumulation_steps (`int`, *optional*, defaults to 1): + Number of updates steps to accumulate the gradients for, before performing a backward/update pass. + + + + When using gradient accumulation, one step is counted as one step with backward pass. Therefore, logging, + evaluation, save will be conducted every `gradient_accumulation_steps * xxx_step` training examples. + + + + learning_rate (`float`, *optional*, defaults to 5e-5): + The initial learning rate for Adam. + weight_decay (`float`, *optional*, defaults to 0): + The weight decay to apply (if not zero). + adam_beta1 (`float`, *optional*, defaults to 0.9): + The beta1 hyperparameter for the Adam optimizer. + adam_beta2 (`float`, *optional*, defaults to 0.999): + The beta2 hyperparameter for the Adam optimizer. + adam_epsilon (`float`, *optional*, defaults to 1e-8): + The epsilon hyperparameter for the Adam optimizer. + max_grad_norm (`float`, *optional*, defaults to 1.0): + Maximum gradient norm (for gradient clipping). + num_train_epochs(`float`, *optional*, defaults to 3.0): + Total number of training epochs to perform. + max_steps (`int`, *optional*, defaults to -1): + If set to a positive number, the total number of training steps to perform. Overrides `num_train_epochs`. + For a finite dataset, training is reiterated through the dataset (if all data is exhausted) until + `max_steps` is reached. + warmup_ratio (`float`, *optional*, defaults to 0.0): + Ratio of total training steps used for a linear warmup from 0 to `learning_rate`. + warmup_steps (`int`, *optional*, defaults to 0): + Number of steps used for a linear warmup from 0 to `learning_rate`. Overrides any effect of `warmup_ratio`. + logging_dir (`str`, *optional*): + [TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to + *runs/**CURRENT_DATETIME_HOSTNAME***. + logging_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"steps"`): + The logging strategy to adopt during training. Possible values are: + + - `"no"`: No logging is done during training. + - `"epoch"`: Logging is done at the end of each epoch. + - `"steps"`: Logging is done every `logging_steps`. + + logging_first_step (`bool`, *optional*, defaults to `False`): + Whether to log and evaluate the first `global_step` or not. + logging_steps (`int`, *optional*, defaults to 500): + Number of update steps between two logs if `logging_strategy="steps"`. + save_strategy (`str` or [`~trainer_utils.SaveStrategy`], *optional*, defaults to `"steps"`): + The checkpoint save strategy to adopt during training. Possible values are: + + - `"no"`: No save is done during training. + - `"epoch"`: Save is done at the end of each epoch. + - `"steps"`: Save is done every `save_steps`. + + save_steps (`int`, *optional*, defaults to 500): + Number of updates steps before two checkpoint saves if `save_strategy="steps"`. + save_total_limit (`int`, *optional*): + If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in + `output_dir`. + no_cuda (`bool`, *optional*, defaults to `False`): + Whether to not use CUDA even when it is available or not. + seed (`int`, *optional*, defaults to 42): + Random seed that will be set at the beginning of training. + fp16 (`bool`, *optional*, defaults to `False`): + Whether to use 16-bit (mixed) precision training (through NVIDIA Apex) instead of 32-bit training. + fp16_opt_level (`str`, *optional*, defaults to 'O1'): + For `fp16` training, Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. See details on + the [Apex documentation](https://nvidia.github.io/apex/amp). + local_rank (`int`, *optional*, defaults to -1): + During distributed training, the rank of the process. + tpu_num_cores (`int`, *optional*): + When training on TPU, the number of TPU cores (automatically passed by launcher script). + debug (`bool`, *optional*, defaults to `False`): + Whether to activate the trace to record computation graphs and profiling information or not. + dataloader_drop_last (`bool`, *optional*, defaults to `False`): + Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch size) + or not. + eval_steps (`int`, *optional*, defaults to 1000): + Number of update steps before two evaluations. + past_index (`int`, *optional*, defaults to -1): + Some models like [TransformerXL](../model_doc/transformerxl) or :doc*XLNet <../model_doc/xlnet>* can make + use of the past hidden states for their predictions. If this argument is set to a positive int, the + `Trainer` will use the corresponding output (usually index 2) as the past state and feed it to the model at + the next training step under the keyword argument `mems`. + tpu_name (`str`, *optional*): + The name of the TPU the process is running on. + tpu_zone (`str`, *optional*): + The zone of the TPU the process is running on. If not specified, we will attempt to automatically detect + from metadata. + gcp_project (`str`, *optional*): + Google Cloud Project name for the Cloud TPU-enabled project. If not specified, we will attempt to + automatically detect from metadata. + run_name (`str`, *optional*): + A descriptor for the run. Notably used for trackio, wandb, mlflow, comet and swanlab logging. + xla (`bool`, *optional*): + Whether to activate the XLA compilation or not. + """ + + framework = "tf" + tpu_name: Optional[str] = field( + default=None, + metadata={"help": "Name of TPU"}, + ) + + tpu_zone: Optional[str] = field( + default=None, + metadata={"help": "Zone of TPU"}, + ) + + gcp_project: Optional[str] = field( + default=None, + metadata={"help": "Name of Cloud TPU-enabled project"}, + ) + + poly_power: float = field( + default=1.0, + metadata={"help": "Power for the Polynomial decay LR scheduler."}, + ) + + xla: bool = field(default=False, metadata={"help": "Whether to activate the XLA compilation or not"}) + + @cached_property + def _setup_strategy(self) -> tuple["tf.distribute.Strategy", int]: + requires_backends(self, ["tf"]) + logger.info("Tensorflow: setting up strategy") + + gpus = tf.config.list_physical_devices("GPU") + + # Set to float16 at first + if self.fp16: + keras.mixed_precision.set_global_policy("mixed_float16") + + if self.no_cuda: + strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0") + else: + try: + if self.tpu_name: + tpu = tf.distribute.cluster_resolver.TPUClusterResolver( + self.tpu_name, zone=self.tpu_zone, project=self.gcp_project + ) + else: + tpu = tf.distribute.cluster_resolver.TPUClusterResolver() + except ValueError: + if self.tpu_name: + raise RuntimeError(f"Couldn't connect to TPU {self.tpu_name}!") + else: + tpu = None + + if tpu: + # Set to bfloat16 in case of TPU + if self.fp16: + keras.mixed_precision.set_global_policy("mixed_bfloat16") + + tf.config.experimental_connect_to_cluster(tpu) + tf.tpu.experimental.initialize_tpu_system(tpu) + + strategy = tf.distribute.TPUStrategy(tpu) + + elif len(gpus) == 0: + strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0") + elif len(gpus) == 1: + strategy = tf.distribute.OneDeviceStrategy(device="/gpu:0") + elif len(gpus) > 1: + # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0` + strategy = tf.distribute.MirroredStrategy() + else: + raise ValueError("Cannot find the proper strategy, please check your environment properties.") + + return strategy + + @property + def strategy(self) -> "tf.distribute.Strategy": + """ + The strategy used for distributed training. + """ + requires_backends(self, ["tf"]) + return self._setup_strategy + + @property + def n_replicas(self) -> int: + """ + The number of replicas (CPUs, GPUs or TPU cores) used in this training. + """ + requires_backends(self, ["tf"]) + return self._setup_strategy.num_replicas_in_sync + + @property + def should_log(self): + """ + Whether or not the current process should produce log. + """ + return False # TF Logging is handled by Keras not the Trainer + + @property + def train_batch_size(self) -> int: + """ + The actual batch size for training (may differ from `per_gpu_train_batch_size` in distributed training). + """ + if self.per_gpu_train_batch_size: + logger.warning( + "Using deprecated `--per_gpu_train_batch_size` argument which will be removed in a future " + "version. Using `--per_device_train_batch_size` is preferred." + ) + per_device_batch_size = self.per_gpu_train_batch_size or self.per_device_train_batch_size + return per_device_batch_size * self.n_replicas + + @property + def eval_batch_size(self) -> int: + """ + The actual batch size for evaluation (may differ from `per_gpu_eval_batch_size` in distributed training). + """ + if self.per_gpu_eval_batch_size: + logger.warning( + "Using deprecated `--per_gpu_eval_batch_size` argument which will be removed in a future " + "version. Using `--per_device_eval_batch_size` is preferred." + ) + per_device_batch_size = self.per_gpu_eval_batch_size or self.per_device_eval_batch_size + return per_device_batch_size * self.n_replicas + + @property + def n_gpu(self) -> int: + """ + The number of replicas (CPUs, GPUs or TPU cores) used in this training. + """ + requires_backends(self, ["tf"]) + warnings.warn( + "The n_gpu argument is deprecated and will be removed in a future version, use n_replicas instead.", + FutureWarning, + ) + return self._setup_strategy.num_replicas_in_sync diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/video_processing_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/video_processing_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0bc81bf8eb28f4d668077ac3574ebf14e6f3fefe --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/video_processing_utils.py @@ -0,0 +1,895 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import warnings +from copy import deepcopy +from functools import partial +from typing import Any, Callable, Optional, Union + +import numpy as np + +from .dynamic_module_utils import custom_object_save +from .image_processing_utils import ( + BatchFeature, + get_size_dict, +) +from .image_processing_utils_fast import BaseImageProcessorFast +from .image_utils import ( + ChannelDimension, + SizeDict, + validate_kwargs, +) +from .processing_utils import Unpack, VideosKwargs +from .utils import ( + IMAGE_PROCESSOR_NAME, + PROCESSOR_NAME, + VIDEO_PROCESSOR_NAME, + TensorType, + add_start_docstrings, + copy_func, + download_url, + is_offline_mode, + is_remote_url, + is_torch_available, + is_torchcodec_available, + is_torchvision_v2_available, + logging, +) +from .utils.hub import cached_file +from .utils.import_utils import requires +from .video_utils import ( + VideoInput, + VideoMetadata, + group_videos_by_shape, + is_valid_video, + load_video, + make_batched_metadata, + make_batched_videos, + reorder_videos, + to_channel_dimension_format, +) + + +if is_torch_available(): + import torch + +if is_torchvision_v2_available(): + from torchvision.transforms.v2 import functional as F + + +logger = logging.get_logger(__name__) + + +BASE_VIDEO_PROCESSOR_DOCSTRING = r""" + Args: + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the video's (height, width) dimensions to the specified `size`. Can be overridden by the + `do_resize` parameter in the `preprocess` method. + size (`dict`, *optional*, defaults to `self.size`): + Size of the output video after resizing. Can be overridden by the `size` parameter in the `preprocess` + method. + size_divisor (`int`, *optional*, defaults to `self.size_divisor`): + The size by which to make sure both the height and width can be divided. + default_to_square (`bool`, *optional*, defaults to `self.default_to_square`): + Whether to default to a square video when resizing, if size is an int. + resample (`PILImageResampling`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the video. Only has an effect if `do_resize` is set to `True`. Can be + overridden by the `resample` parameter in the `preprocess` method. + do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`): + Whether to center crop the video to the specified `crop_size`. Can be overridden by `do_center_crop` in the + `preprocess` method. + crop_size (`dict[str, int]` *optional*, defaults to `self.crop_size`): + Size of the output video after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess` + method. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the video by the specified scale `rescale_factor`. Can be overridden by the + `do_rescale` parameter in the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `self.rescale_factor`): + Scale factor to use if rescaling the video. Only has an effect if `do_rescale` is set to `True`. Can be + overridden by the `rescale_factor` parameter in the `preprocess` method. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the video. Can be overridden by the `do_normalize` parameter in the `preprocess` + method. Can be overridden by the `do_normalize` parameter in the `preprocess` method. + image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`): + Mean to use if normalizing the video. This is a float or list of floats the length of the number of + channels in the video. Can be overridden by the `image_mean` parameter in the `preprocess` method. Can be + overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`): + Standard deviation to use if normalizing the video. This is a float or list of floats the length of the + number of channels in the video. Can be overridden by the `image_std` parameter in the `preprocess` method. + Can be overridden by the `image_std` parameter in the `preprocess` method. + do_convert_rgb (`bool`, *optional*, defaults to `self.image_std`): + Whether to convert the video to RGB. + video_metadata (`VideoMetadata`, *optional*): + Metadata of the video containing information about total duration, fps and total number of frames. + do_sample_frames (`int`, *optional*, defaults to `self.do_sample_frames`): + Whether to sample frames from the video before processing or to process the whole video. + num_frames (`int`, *optional*, defaults to `self.num_frames`): + Maximum number of frames to sample when `do_sample_frames=True`. + fps (`int` or `float`, *optional*, defaults to `self.fps`): + Target frames to sample per second when `do_sample_frames=True`. + return_tensors (`str` or `TensorType`, *optional*): + Returns stacked tensors if set to `pt, otherwise returns a list of tensors. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output video. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: video in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: video in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input video. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input video. If unset, the channel dimension format is inferred + from the input video. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: video in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: video in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: video in (height, width) format. + device (`torch.device`, *optional*): + The device to process the videos on. If unset, the device is inferred from the input videos. + return_metadata (`bool`, *optional*): + Whether to return video metadata or not. + """ + + +@add_start_docstrings( + "Constructs a base VideoProcessor.", + BASE_VIDEO_PROCESSOR_DOCSTRING, +) +@requires(backends=("vision", "torchvision")) +class BaseVideoProcessor(BaseImageProcessorFast): + _auto_class = None + + resample = None + image_mean = None + image_std = None + size = None + size_divisor = None + default_to_square = True + crop_size = None + do_resize = None + do_center_crop = None + do_rescale = None + rescale_factor = 1 / 255 + do_normalize = None + do_convert_rgb = None + do_sample_frames = None + fps = None + num_frames = None + video_metadata = None + return_metadata = False + valid_kwargs = VideosKwargs + model_input_names = ["pixel_values_videos"] + + def __init__(self, **kwargs: Unpack[VideosKwargs]) -> None: + super().__init__() + + self._processor_class = kwargs.pop("processor_class", None) + + # Additional attributes without default values + for key, value in kwargs.items(): + try: + setattr(self, key, value) + except AttributeError as err: + logger.error(f"Can't set {key} with value {value} for {self}") + raise err + + # Prepare size related keys and turn then into `SizeDict` + size = kwargs.pop("size", self.size) + self.size = ( + get_size_dict(size=size, default_to_square=kwargs.pop("default_to_square", self.default_to_square)) + if size is not None + else None + ) + crop_size = kwargs.pop("crop_size", self.crop_size) + self.crop_size = get_size_dict(crop_size, param_name="crop_size") if crop_size is not None else None + + # Save valid kwargs in a list for further processing + self.model_valid_processing_keys = list(self.valid_kwargs.__annotations__.keys()) + for key in self.model_valid_processing_keys: + if kwargs.get(key) is not None: + setattr(self, key, kwargs[key]) + else: + setattr(self, key, deepcopy(getattr(self, key, None))) + + def __call__(self, videos, **kwargs) -> BatchFeature: + return self.preprocess(videos, **kwargs) + + def convert_to_rgb( + self, + video: "torch.Tensor", + ) -> VideoInput: + """ + Converts a video to RGB format. + + Args: + video (`"torch.Tensor"`): + The video to convert. + + Returns: + `torch.Tensor`: The converted video. + """ + + video = F.grayscale_to_rgb(video) + if video.shape[-3] == 3 or not (video[..., 3, :, :] < 255).any(): + return video + + # There is a transparency layer, blend it with a white background. + # Calculate the alpha proportion for blending. + alpha = video[..., 3, :, :] / 255.0 + video = (1 - alpha[..., None, :, :]) * 255 + alpha[..., None, :, :] * video[..., :3, :, :] + return video + + def sample_frames( + self, + metadata: VideoMetadata, + num_frames: Optional[int] = None, + fps: Optional[Union[int, float]] = None, + **kwargs, + ): + """ + Default sampling function which uniformly samples the desired number of frames between 0 and total number of frames. + If `fps` is passed along with metadata, `fps` frames per second are sampled uniformty. Arguments `num_frames` + and `fps` are mutually exclusive. + + Args: + metadata (`VideoMetadata`): + Metadata of the video containing information about total duration, fps and total number of frames. + num_frames (`int`, *optional*): + Maximum number of frames to sample. Defaults to `self.num_frames`. + fps (`int` or `float`, *optional*): + Target frames to sample per second. Defaults to `self.fps`. + + Returns: + np.ndarray: + Indices to sample video frames. + """ + if fps is not None and num_frames is not None: + raise ValueError( + "`num_frames`, `fps`, and `sample_indices_fn` are mutually exclusive arguments, please use only one!" + ) + + num_frames = num_frames if num_frames is not None else self.num_frames + fps = fps if fps is not None else self.fps + total_num_frames = metadata.total_num_frames + + # If num_frames is not given but fps is, calculate num_frames from fps + if num_frames is None and fps is not None: + if metadata is None or metadata.fps is None: + raise ValueError( + "Asked to sample `fps` frames per second but no video metadata was provided which is required when sampling with `fps`. " + "Please pass in `VideoMetadata` object or use a fixed `num_frames` per input video" + ) + num_frames = int(total_num_frames / metadata.fps * fps) + + if num_frames > total_num_frames: + raise ValueError( + f"Video can't be sampled. The `num_frames={num_frames}` exceeds `total_num_frames={total_num_frames}`. " + ) + + if num_frames is not None: + indices = torch.arange(0, total_num_frames, total_num_frames / num_frames).int() + else: + indices = torch.arange(0, total_num_frames).int() + return indices + + def _decode_and_sample_videos( + self, + videos: VideoInput, + video_metadata: Union[VideoMetadata, dict], + do_sample_frames: Optional[bool] = None, + sample_indices_fn: Optional[Callable] = None, + ) -> list["torch.Tensor"]: + """ + Decode input videos and sample frames if needed. + """ + videos = make_batched_videos(videos) + video_metadata = make_batched_metadata(videos, video_metadata=video_metadata) + + # Only sample frames if an array video is passed, otherwise first decode -> then sample + if is_valid_video(videos[0]) and do_sample_frames: + sampled_videos = [] + sampled_metadata = [] + for video, metadata in zip(videos, video_metadata): + indices = sample_indices_fn(metadata=metadata) + metadata.frames_indices = indices + sampled_videos.append(video[indices]) + sampled_metadata.append(metadata) + videos = sampled_videos + video_metadata = sampled_metadata + elif not is_valid_video(videos[0]): + if isinstance(videos[0], list): + # Videos sometimes are passed as a list of image URLs, especially through templates + videos = [ + torch.stack([F.pil_to_tensor(image) for image in images], dim=0) + for images in self.fetch_images(videos) + ] + if do_sample_frames: + raise ValueError( + "Sampling frames from a list of images is not supported! Set `do_sample_frames=False`." + ) + else: + videos, video_metadata = self.fetch_videos(videos, sample_indices_fn=sample_indices_fn) + + return videos, video_metadata + + def _prepare_input_videos( + self, + videos: VideoInput, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + device: Optional[str] = None, + ) -> list["torch.Tensor"]: + """ + Prepare the input videos for processing. + """ + processed_videos = [] + for video in videos: + # `make_batched_videos` always returns a 4D array per video + if isinstance(video, np.ndarray): + video = to_channel_dimension_format(video, ChannelDimension.FIRST, input_data_format) + # not using F.to_tensor as it doesn't handle (C, H, W) numpy arrays + video = torch.from_numpy(video).contiguous() + + if device is not None: + video = video.to(device) + + processed_videos.append(video) + return processed_videos + + @add_start_docstrings( + BASE_VIDEO_PROCESSOR_DOCSTRING, + ) + def preprocess( + self, + videos: VideoInput, + **kwargs: Unpack[VideosKwargs], + ) -> BatchFeature: + validate_kwargs( + captured_kwargs=kwargs.keys(), + valid_processor_keys=list(self.valid_kwargs.__annotations__.keys()) + ["return_tensors"], + ) + # Set default kwargs from self. This ensures that if a kwarg is not provided + # by the user, it gets its default value from the instance, or is set to None. + for kwarg_name in self.valid_kwargs.__annotations__: + kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None)) + + input_data_format = kwargs.pop("input_data_format") + do_sample_frames = kwargs.pop("do_sample_frames") + device = kwargs.pop("device") + video_metadata = kwargs.pop("video_metadata") + + sample_indices_fn = partial(self.sample_frames, **kwargs) if do_sample_frames else None + videos, video_metadata = self._decode_and_sample_videos( + videos, + video_metadata=video_metadata, + do_sample_frames=do_sample_frames, + sample_indices_fn=sample_indices_fn, + ) + videos = self._prepare_input_videos(videos=videos, input_data_format=input_data_format, device=device) + + kwargs = self._further_process_kwargs(**kwargs) + self._validate_preprocess_kwargs(**kwargs) + + # Pop kwargs that are not needed in _preprocess + kwargs.pop("data_format") + return_metadata = kwargs.pop("return_metadata") + + preprocessed_videos = self._preprocess(videos=videos, **kwargs) + if return_metadata: + preprocessed_videos["video_metadata"] = video_metadata + return preprocessed_videos + + def _preprocess( + self, + videos: list["torch.Tensor"], + do_convert_rgb: bool, + do_resize: bool, + size: SizeDict, + interpolation: Optional["F.InterpolationMode"], + do_center_crop: bool, + crop_size: SizeDict, + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: Optional[Union[float, list[float]]], + image_std: Optional[Union[float, list[float]]], + return_tensors: Optional[Union[str, TensorType]] = None, + **kwargs, + ) -> BatchFeature: + # Group videos by size for batched resizing + grouped_videos, grouped_videos_index = group_videos_by_shape(videos) + resized_videos_grouped = {} + for shape, stacked_videos in grouped_videos.items(): + if do_convert_rgb: + stacked_videos = self.convert_to_rgb(stacked_videos) + if do_resize: + stacked_videos = self.resize(stacked_videos, size=size, interpolation=interpolation) + resized_videos_grouped[shape] = stacked_videos + resized_videos = reorder_videos(resized_videos_grouped, grouped_videos_index) + + # Group videos by size for further processing + # Needed in case do_resize is False, or resize returns videos with different sizes + grouped_videos, grouped_videos_index = group_videos_by_shape(resized_videos) + processed_videos_grouped = {} + for shape, stacked_videos in grouped_videos.items(): + if do_center_crop: + stacked_videos = self.center_crop(stacked_videos, crop_size) + # Fused rescale and normalize + stacked_videos = self.rescale_and_normalize( + stacked_videos, do_rescale, rescale_factor, do_normalize, image_mean, image_std + ) + processed_videos_grouped[shape] = stacked_videos + + processed_videos = reorder_videos(processed_videos_grouped, grouped_videos_index) + processed_videos = torch.stack(processed_videos, dim=0) if return_tensors else processed_videos + + return BatchFeature(data={"pixel_values_videos": processed_videos}, tensor_type=return_tensors) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Union[str, os.PathLike], + cache_dir: Optional[Union[str, os.PathLike]] = None, + force_download: bool = False, + local_files_only: bool = False, + token: Optional[Union[str, bool]] = None, + revision: str = "main", + **kwargs, + ): + r""" + Instantiate a type of [`~video_processing_utils.VideoProcessorBase`] from an video processor. + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`): + This can be either: + + - a string, the *model id* of a pretrained video hosted inside a model repo on + huggingface.co. + - a path to a *directory* containing a video processor file saved using the + [`~video_processing_utils.VideoProcessorBase.save_pretrained`] method, e.g., + `./my_model_directory/`. + - a path or url to a saved video processor JSON *file*, e.g., + `./my_model_directory/video_preprocessor_config.json`. + cache_dir (`str` or `os.PathLike`, *optional*): + Path to a directory in which a downloaded pretrained model video processor should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force to (re-)download the video processor files and override the cached versions if + they exist. + resume_download: + Deprecated and ignored. All downloads are now resumed by default when possible. + Will be removed in v5 of Transformers. + proxies (`dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. + token (`str` or `bool`, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use + the token generated when running `hf auth login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + + + + + To test a pull request you made on the Hub, you can pass `revision="refs/pr/"`. + + + + return_unused_kwargs (`bool`, *optional*, defaults to `False`): + If `False`, then this function returns just the final video processor object. If `True`, then this + functions returns a `Tuple(video_processor, unused_kwargs)` where *unused_kwargs* is a dictionary + consisting of the key/value pairs whose keys are not video processor attributes: i.e., the part of + `kwargs` which has not been used to update `video_processor` and is otherwise ignored. + subfolder (`str`, *optional*, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can + specify the folder name here. + kwargs (`dict[str, Any]`, *optional*): + The values in kwargs of any keys which are video processor attributes will be used to override the + loaded values. Behavior concerning key/value pairs whose keys are *not* video processor attributes is + controlled by the `return_unused_kwargs` keyword parameter. + + Returns: + A video processor of type [`~video_processing_utils.ImagVideoProcessorBase`]. + + Examples: + + ```python + # We can't instantiate directly the base class *VideoProcessorBase* so let's show the examples on a + # derived class: *LlavaOnevisionVideoProcessor* + video_processor = LlavaOnevisionVideoProcessor.from_pretrained( + "llava-hf/llava-onevision-qwen2-0.5b-ov-hf" + ) # Download video_processing_config from huggingface.co and cache. + video_processor = LlavaOnevisionVideoProcessor.from_pretrained( + "./test/saved_model/" + ) # E.g. video processor (or model) was saved using *save_pretrained('./test/saved_model/')* + video_processor = LlavaOnevisionVideoProcessor.from_pretrained("./test/saved_model/video_preprocessor_config.json") + video_processor = LlavaOnevisionVideoProcessor.from_pretrained( + "llava-hf/llava-onevision-qwen2-0.5b-ov-hf", do_normalize=False, foo=False + ) + assert video_processor.do_normalize is False + video_processor, unused_kwargs = LlavaOnevisionVideoProcessor.from_pretrained( + "llava-hf/llava-onevision-qwen2-0.5b-ov-hf", do_normalize=False, foo=False, return_unused_kwargs=True + ) + assert video_processor.do_normalize is False + assert unused_kwargs == {"foo": False} + ```""" + kwargs["cache_dir"] = cache_dir + kwargs["force_download"] = force_download + kwargs["local_files_only"] = local_files_only + kwargs["revision"] = revision + + use_auth_token = kwargs.pop("use_auth_token", None) + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if token is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + token = use_auth_token + + if token is not None: + kwargs["token"] = token + + video_processor_dict, kwargs = cls.get_video_processor_dict(pretrained_model_name_or_path, **kwargs) + + return cls.from_dict(video_processor_dict, **kwargs) + + def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): + """ + Save an video processor object to the directory `save_directory`, so that it can be re-loaded using the + [`~video_processing_utils.VideoProcessorBase.from_pretrained`] class method. + + Args: + save_directory (`str` or `os.PathLike`): + Directory where the video processor JSON file will be saved (will be created if it does not exist). + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the + repository you want to push to with `repo_id` (will default to the name of `save_directory` in your + namespace). + kwargs (`dict[str, Any]`, *optional*): + Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. + """ + use_auth_token = kwargs.pop("use_auth_token", None) + + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if kwargs.get("token") is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + kwargs["token"] = use_auth_token + + if os.path.isfile(save_directory): + raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") + + os.makedirs(save_directory, exist_ok=True) + + if push_to_hub: + commit_message = kwargs.pop("commit_message", None) + repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) + repo_id = self._create_repo(repo_id, **kwargs) + files_timestamps = self._get_files_timestamps(save_directory) + + # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be + # loaded from the Hub. + if self._auto_class is not None: + custom_object_save(self, save_directory, config=self) + + # If we save using the predefined names, we can load using `from_pretrained` + output_video_processor_file = os.path.join(save_directory, VIDEO_PROCESSOR_NAME) + + self.to_json_file(output_video_processor_file) + logger.info(f"Video processor saved in {output_video_processor_file}") + + if push_to_hub: + self._upload_modified_files( + save_directory, + repo_id, + files_timestamps, + commit_message=commit_message, + token=kwargs.get("token"), + ) + + return [output_video_processor_file] + + @classmethod + def get_video_processor_dict( + cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs + ) -> tuple[dict[str, Any], dict[str, Any]]: + """ + From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a + video processor of type [`~video_processing_utils.VideoProcessorBase`] using `from_dict`. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`): + The identifier of the pre-trained checkpoint from which we want the dictionary of parameters. + subfolder (`str`, *optional*, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can + specify the folder name here. + + Returns: + `tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the video processor object. + """ + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", None) + proxies = kwargs.pop("proxies", None) + token = kwargs.pop("token", None) + use_auth_token = kwargs.pop("use_auth_token", None) + local_files_only = kwargs.pop("local_files_only", False) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", "") + + from_pipeline = kwargs.pop("_from_pipeline", None) + from_auto_class = kwargs.pop("_from_auto", False) + + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if token is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + token = use_auth_token + + user_agent = {"file_type": "video processor", "from_auto_class": from_auto_class} + if from_pipeline is not None: + user_agent["using_pipeline"] = from_pipeline + + if is_offline_mode() and not local_files_only: + logger.info("Offline mode: forcing local_files_only=True") + local_files_only = True + + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + is_local = os.path.isdir(pretrained_model_name_or_path) + if os.path.isfile(pretrained_model_name_or_path): + resolved_video_processor_file = pretrained_model_name_or_path + is_local = True + elif is_remote_url(pretrained_model_name_or_path): + video_processor_file = pretrained_model_name_or_path + resolved_video_processor_file = download_url(pretrained_model_name_or_path) + else: + video_processor_file = VIDEO_PROCESSOR_NAME + try: + # Try to load with a new config name first and if not successful try with the old file name + # NOTE: we will gradually change to saving all processor configs as nested dict in PROCESSOR_NAME + resolved_video_processor_files = [ + resolved_file + for filename in [VIDEO_PROCESSOR_NAME, IMAGE_PROCESSOR_NAME, PROCESSOR_NAME] + if ( + resolved_file := cached_file( + pretrained_model_name_or_path, + filename=filename, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + revision=revision, + subfolder=subfolder, + _raise_exceptions_for_missing_entries=False, + ) + ) + is not None + ] + resolved_video_processor_file = resolved_video_processor_files[0] + except OSError: + # Raise any OS error raise by `cached_file`. It will have a helpful error message adapted to + # the original exception. + raise + except Exception: + # For any other exception, we throw a generic error. + raise OSError( + f"Can't load video processor for '{pretrained_model_name_or_path}'. If you were trying to load" + " it from 'https://huggingface.co/models', make sure you don't have a local directory with the" + f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a" + f" directory containing a {VIDEO_PROCESSOR_NAME} file" + ) + + try: + # Load video_processor dict + with open(resolved_video_processor_file, "r", encoding="utf-8") as reader: + text = reader.read() + video_processor_dict = json.loads(text) + video_processor_dict = video_processor_dict.get("video_processor", video_processor_dict) + + except json.JSONDecodeError: + raise OSError( + f"It looks like the config file at '{resolved_video_processor_file}' is not a valid JSON file." + ) + + if is_local: + logger.info(f"loading configuration file {resolved_video_processor_file}") + else: + logger.info( + f"loading configuration file {video_processor_file} from cache at {resolved_video_processor_file}" + ) + return video_processor_dict, kwargs + + @classmethod + def from_dict(cls, video_processor_dict: dict[str, Any], **kwargs): + """ + Instantiates a type of [`~video_processing_utils.VideoProcessorBase`] from a Python dictionary of parameters. + + Args: + video_processor_dict (`dict[str, Any]`): + Dictionary that will be used to instantiate the video processor object. Such a dictionary can be + retrieved from a pretrained checkpoint by leveraging the + [`~video_processing_utils.VideoProcessorBase.to_dict`] method. + kwargs (`dict[str, Any]`): + Additional parameters from which to initialize the video processor object. + + Returns: + [`~video_processing_utils.VideoProcessorBase`]: The video processor object instantiated from those + parameters. + """ + video_processor_dict = video_processor_dict.copy() + return_unused_kwargs = kwargs.pop("return_unused_kwargs", False) + + # The `size` parameter is a dict and was previously an int or tuple in feature extractors. + # We set `size` here directly to the `video_processor_dict` so that it is converted to the appropriate + # dict within the video processor and isn't overwritten if `size` is passed in as a kwarg. + if "size" in kwargs and "size" in video_processor_dict: + video_processor_dict["size"] = kwargs.pop("size") + if "crop_size" in kwargs and "crop_size" in video_processor_dict: + video_processor_dict["crop_size"] = kwargs.pop("crop_size") + + video_processor = cls(**video_processor_dict) + + # Update video_processor with kwargs if needed + to_remove = [] + for key, value in kwargs.items(): + if hasattr(video_processor, key): + setattr(video_processor, key, value) + to_remove.append(key) + for key in to_remove: + kwargs.pop(key, None) + + logger.info(f"Video processor {video_processor}") + if return_unused_kwargs: + return video_processor, kwargs + else: + return video_processor + + def to_dict(self) -> dict[str, Any]: + """ + Serializes this instance to a Python dictionary. + + Returns: + `dict[str, Any]`: Dictionary of all the attributes that make up this video processor instance. + """ + output = deepcopy(self.__dict__) + output.pop("model_valid_processing_keys", None) + output.pop("_valid_kwargs_names", None) + output["video_processor_type"] = self.__class__.__name__ + + return output + + def to_json_string(self) -> str: + """ + Serializes this instance to a JSON string. + + Returns: + `str`: String containing all the attributes that make up this feature_extractor instance in JSON format. + """ + dictionary = self.to_dict() + + for key, value in dictionary.items(): + if isinstance(value, np.ndarray): + dictionary[key] = value.tolist() + + # make sure private name "_processor_class" is correctly + # saved as "processor_class" + _processor_class = dictionary.pop("_processor_class", None) + if _processor_class is not None: + dictionary["processor_class"] = _processor_class + + return json.dumps(dictionary, indent=2, sort_keys=True) + "\n" + + def to_json_file(self, json_file_path: Union[str, os.PathLike]): + """ + Save this instance to a JSON file. + + Args: + json_file_path (`str` or `os.PathLike`): + Path to the JSON file in which this image_processor instance's parameters will be saved. + """ + with open(json_file_path, "w", encoding="utf-8") as writer: + writer.write(self.to_json_string()) + + def __repr__(self): + return f"{self.__class__.__name__} {self.to_json_string()}" + + @classmethod + def from_json_file(cls, json_file: Union[str, os.PathLike]): + """ + Instantiates a video processor of type [`~video_processing_utils.VideoProcessorBase`] from the path to a JSON + file of parameters. + + Args: + json_file (`str` or `os.PathLike`): + Path to the JSON file containing the parameters. + + Returns: + A video processor of type [`~video_processing_utils.VideoProcessorBase`]: The video_processor object + instantiated from that JSON file. + """ + with open(json_file, "r", encoding="utf-8") as reader: + text = reader.read() + video_processor_dict = json.loads(text) + return cls(**video_processor_dict) + + @classmethod + def register_for_auto_class(cls, auto_class="AutoVideoProcessor"): + """ + Register this class with a given auto class. This should only be used for custom video processors as the ones + in the library are already mapped with `AutoVideoProcessor `. + + + + This API is experimental and may have some slight breaking changes in the next releases. + + + + Args: + auto_class (`str` or `type`, *optional*, defaults to `"AutoVideoProcessor "`): + The auto class to register this new video processor with. + """ + if not isinstance(auto_class, str): + auto_class = auto_class.__name__ + + import transformers.models.auto as auto_module + + if not hasattr(auto_module, auto_class): + raise ValueError(f"{auto_class} is not a valid auto class.") + + cls._auto_class = auto_class + + def fetch_videos(self, video_url_or_urls: Union[str, list[str], list[list[str]]], sample_indices_fn=None): + """ + Convert a single or a list of urls into the corresponding `np.array` objects. + + If a single url is passed, the return value will be a single object. If a list is passed a list of objects is + returned. + """ + backend = "torchcodec" + if not is_torchcodec_available(): + warnings.warn( + "`torchcodec` is not installed and cannot be used to decode the video by default. " + "Falling back to `torchvision`. Note that `torchvision` decoding is deprecated and will be removed in future versions. " + ) + backend = "torchvision" + + if isinstance(video_url_or_urls, list): + return list(zip(*[self.fetch_videos(x, sample_indices_fn=sample_indices_fn) for x in video_url_or_urls])) + else: + return load_video(video_url_or_urls, backend=backend, sample_indices_fn=sample_indices_fn) + + +BaseVideoProcessor.push_to_hub = copy_func(BaseVideoProcessor.push_to_hub) +if BaseVideoProcessor.push_to_hub.__doc__ is not None: + BaseVideoProcessor.push_to_hub.__doc__ = BaseVideoProcessor.push_to_hub.__doc__.format( + object="video processor", object_class="AutoVideoProcessor", object_files="video processor file" + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/video_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/video_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2ed5720a8e410e061347dd4946fa59f5f9594bb1 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/transformers/video_utils.py @@ -0,0 +1,878 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import warnings +from collections.abc import Iterable, Mapping +from contextlib import redirect_stdout +from dataclasses import dataclass, fields +from io import BytesIO +from typing import Callable, NewType, Optional, Union +from urllib.parse import urlparse + +import numpy as np +import requests + +from .image_transforms import PaddingMode, to_channel_dimension_format +from .image_utils import ChannelDimension, infer_channel_dimension_format, is_valid_image +from .utils import ( + is_av_available, + is_cv2_available, + is_decord_available, + is_numpy_array, + is_torch_available, + is_torch_tensor, + is_torchcodec_available, + is_torchvision_available, + is_vision_available, + is_yt_dlp_available, + logging, + requires_backends, +) + + +if is_vision_available(): + import PIL.Image + import PIL.ImageOps + + if is_torchvision_available(): + from torchvision import io as torchvision_io + +if is_torch_available(): + import torch + + +logger = logging.get_logger(__name__) + +URL = NewType("URL", str) +Path = NewType("Path", str) + +VideoInput = Union[ + list["PIL.Image.Image"], + np.ndarray, + "torch.Tensor", + list[np.ndarray], + list["torch.Tensor"], + list[list["PIL.Image.Image"]], + list[list[np.ndarray]], + list[list["torch.Tensor"]], + URL, + list[URL], + list[list[URL]], + Path, + list[Path], + list[list[Path]], +] + + +@dataclass +class VideoMetadata(Mapping): + total_num_frames: int + fps: Optional[float] = None + width: Optional[int] = None + height: Optional[int] = None + duration: Optional[float] = None + video_backend: Optional[str] = None + frames_indices: Optional[list[int]] = None + + def __iter__(self): + return (f.name for f in fields(self)) + + def __len__(self): + return len(fields(self)) + + def __getitem__(self, item): + return getattr(self, item) + + def __setitem__(self, key, value): + return setattr(self, key, value) + + @property + def timestamps(self) -> list[float]: + "Timestamps of the sampled frames in seconds." + if self.fps is None or self.frames_indices is None: + raise ValueError("Cannot infer video `timestamps` when `fps` or `frames_indices` is None.") + return [frame_idx / self.fps for frame_idx in self.frames_indices] + + def update(self, dictionary): + for key, value in dictionary.items(): + if hasattr(self, key): + setattr(self, key, value) + + +def is_valid_video_frame(frame): + return isinstance(frame, PIL.Image.Image) or ( + (is_numpy_array(frame) or is_torch_tensor(frame)) and frame.ndim == 3 + ) + + +def is_valid_video(video): + if not isinstance(video, (list, tuple)): + return (is_numpy_array(video) or is_torch_tensor(video)) and video.ndim == 4 + return video and all(is_valid_video_frame(frame) for frame in video) + + +def valid_videos(videos): + # If we have a list of videos, it could be either one video as list of frames or a batch + if isinstance(videos, (list, tuple)): + for video_or_frame in videos: + if not (is_valid_video(video_or_frame) or is_valid_video_frame(video_or_frame)): + return False + # If not a list, then we have a single 4D video or 5D batched tensor + elif not is_valid_video(videos) or videos.ndim == 5: + return False + return True + + +def is_batched_video(videos): + if isinstance(videos, (list, tuple)): + return is_valid_video(videos[0]) + elif (is_numpy_array(videos) or is_torch_tensor(videos)) and videos.ndim == 5: + return True + return False + + +def is_scaled_video(video: np.ndarray) -> bool: + """ + Checks to see whether the pixel values have already been rescaled to [0, 1]. + """ + # It's possible the video has pixel values in [0, 255] but is of floating type + return np.min(video) >= 0 and np.max(video) <= 1 + + +def convert_pil_frames_to_video(videos: list[VideoInput]) -> list[Union[np.ndarray, "torch.Tensor"]]: + """ + Given a batch of videos, converts each video to a 4D array. If video is already in array type, + it is simply returned. We assume that all inputs in the list are in the same format, based on the type of the first element. + + Args: + videos (`VideoInput`): + Video inputs to turn into a list of videos. + """ + + if not (isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0])): + return videos + + video_converted = [] + for video in videos: + video = [np.array(frame) for frame in video] + video = np.stack(video) + video_converted.append(video) + return video_converted + + +def make_batched_videos(videos) -> list[Union[np.ndarray, "torch.Tensor", "URL", "Path"]]: + """ + Ensure that the input is a list of videos. If the input is a single video, it is converted to a list of length 1. + If the input is a batch of videos, it is converted to a list of 4D video arrays. Videos passed as list `PIL.Image` + frames are converted to 4D arrays. + + We assume that all inputs in the list are in the same format, based on the type of the first element. + + Args: + videos (`VideoInput`): + Video inputs to turn into a list of videos. + """ + # Early exit for deeply nested list of image frame paths. We shouldn't flatten them + try: + if isinstance(videos[0][0], list) and isinstance(videos[0][0][0], str): + return [image_paths for sublist in videos for image_paths in sublist] + except (IndexError, TypeError): + pass + + if isinstance(videos, str) or is_valid_video(videos): + return convert_pil_frames_to_video([videos]) + # only one frame passed, thus we unsqueeze time dim + elif is_valid_image(videos): + if isinstance(videos, PIL.Image.Image): + videos = np.array(videos) + return [videos[None, ...]] + elif not isinstance(videos, list): + raise ValueError( + f"Invalid video input. Expected either a list of video frames or an input of 4 or 5 dimensions, but got" + f" type {type(videos)}." + ) + + # Recursively flatten any nested structure + flat_videos_list = [] + for item in videos: + if isinstance(item, str) or is_valid_video(item): + flat_videos_list.append(item) + elif isinstance(item, list) and item: + flat_videos_list.extend(make_batched_videos(item)) + + flat_videos_list = convert_pil_frames_to_video(flat_videos_list) + return flat_videos_list + + +def make_batched_metadata(videos: VideoInput, video_metadata: Union[VideoMetadata, dict]): + if video_metadata is None: + # Create default metadata and fill attributes we can infer from given video + video_metadata = [ + { + "total_num_frames": len(video), + "fps": None, + "duration": None, + "frames_indices": list(range(len(video))), + "height": get_video_size(video)[0] if is_valid_video(video) else None, + "width": get_video_size(video)[1] if is_valid_video(video) else None, + } + for video in videos + ] + + if isinstance(video_metadata, list): + # Flatten if nested list + if isinstance(video_metadata[0], list): + video_metadata = [ + VideoMetadata(**metadata) for metadata_list in video_metadata for metadata in metadata_list + ] + # Simply wrap in VideoMetadata if simple dict + elif isinstance(video_metadata[0], dict): + video_metadata = [VideoMetadata(**metadata) for metadata in video_metadata] + else: + # Create a batched list from single object + video_metadata = [VideoMetadata(**video_metadata)] + return video_metadata + + +def get_video_size(video: np.ndarray, channel_dim: Optional[ChannelDimension] = None) -> tuple[int, int]: + """ + Returns the (height, width) dimensions of the video. + + Args: + video (`np.ndarray`): + The video to get the dimensions of. + channel_dim (`ChannelDimension`, *optional*): + Which dimension the channel dimension is in. If `None`, will infer the channel dimension from the video. + + Returns: + A tuple of the video's height and width. + """ + if channel_dim is None: + channel_dim = infer_channel_dimension_format(video, num_channels=(1, 3, 4)) + + if channel_dim == ChannelDimension.FIRST: + return video.shape[-2], video.shape[-1] + elif channel_dim == ChannelDimension.LAST: + return video.shape[-3], video.shape[-2] + else: + raise ValueError(f"Unsupported data format: {channel_dim}") + + +def get_uniform_frame_indices(total_num_frames: int, num_frames: Optional[int] = None): + """ + Creates a numpy array for uniform sampling of `num_frame` frames from `total_num_frames` + when loading a video. + + Args: + total_num_frames (`int`): + Total number of frames that a video has. + num_frames (`int`, *optional*): + Number of frames to sample uniformly. If not specified, all frames are sampled. + + Returns: + np.ndarray: np array of frame indices that will be sampled. + """ + if num_frames is not None: + indices = np.arange(0, total_num_frames, total_num_frames / num_frames).astype(int) + else: + indices = np.arange(0, total_num_frames).astype(int) + return indices + + +def default_sample_indices_fn(metadata: VideoMetadata, num_frames=None, fps=None, **kwargs): + """ + A default sampling function that replicates the logic used in get_uniform_frame_indices, + while optionally handling `fps` if `num_frames` is not provided. + + Args: + metadata (`VideoMetadata`): + `VideoMetadata` object containing metadata about the video, such as "total_num_frames" or "fps". + num_frames (`int`, *optional*): + Number of frames to sample uniformly. + fps (`int` or `float`, *optional*): + Desired frames per second. Takes priority over num_frames if both are provided. + + Returns: + `np.ndarray`: Array of frame indices to sample. + """ + total_num_frames = metadata.total_num_frames + video_fps = metadata.fps + + # If num_frames is not given but fps is, calculate num_frames from fps + if num_frames is None and fps is not None: + num_frames = int(total_num_frames / video_fps * fps) + if num_frames > total_num_frames: + raise ValueError( + f"When loading the video with fps={fps}, we computed num_frames={num_frames} " + f"which exceeds total_num_frames={total_num_frames}. Check fps or video metadata." + ) + + if num_frames is not None: + indices = np.arange(0, total_num_frames, total_num_frames / num_frames, dtype=int) + else: + indices = np.arange(0, total_num_frames, dtype=int) + return indices + + +def read_video_opencv( + video_path: Union["URL", "Path"], + sample_indices_fn: Callable, + **kwargs, +) -> tuple[np.ndarray, VideoMetadata]: + """ + Decode a video using the OpenCV backend. + + Args: + video_path (`str`): + Path to the video file. + sample_indices_fn (`Callable`): + A callable function that will return indices at which the video should be sampled. If the video has to be loaded using + by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`. + If not provided, simple uniform sampling with fps is performed. + Example: + def sample_indices_fn(metadata, **kwargs): + return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int) + + Returns: + tuple[`np.ndarray`, `VideoMetadata`]: A tuple containing: + - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]). + - `VideoMetadata` object. + """ + # Lazy import cv2 + requires_backends(read_video_opencv, ["cv2"]) + import cv2 + + video = cv2.VideoCapture(video_path) + total_num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) + video_fps = video.get(cv2.CAP_PROP_FPS) + duration = total_num_frames / video_fps if video_fps else 0 + metadata = VideoMetadata( + total_num_frames=int(total_num_frames), + fps=float(video_fps), + duration=float(duration), + video_backend="opencv", + height=int(video.get(cv2.CAP_PROP_FRAME_HEIGHT)), + width=int(video.get(cv2.CAP_PROP_FRAME_WIDTH)), + ) + indices = sample_indices_fn(metadata=metadata, **kwargs) + + index = 0 + frames = [] + while video.isOpened(): + success, frame = video.read() + if not success: + break + if index in indices: + height, width, channel = frame.shape + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frames.append(frame[0:height, 0:width, 0:channel]) + if success: + index += 1 + if index >= total_num_frames: + break + + video.release() + metadata.frames_indices = indices + return np.stack(frames), metadata + + +def read_video_decord( + video_path: Union["URL", "Path"], + sample_indices_fn: Callable, + **kwargs, +): + """ + Decode a video using the Decord backend. + + Args: + video_path (`str`): + Path to the video file. + sample_indices_fn (`Callable`): + A callable function that will return indices at which the video should be sampled. If the video has to be loaded using + by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`. + If not provided, simple uniform sampling with fps is performed. + Example: + def sample_indices_fn(metadata, **kwargs): + return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int) + + Returns: + tuple[`np.array`, `VideoMetadata`]: A tuple containing: + - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]). + - `VideoMetadata` object. + """ + # Lazy import from decord + requires_backends(read_video_decord, ["decord"]) + from decord import VideoReader, cpu + + vr = VideoReader(uri=video_path, ctx=cpu(0)) # decord has problems with gpu + video_fps = vr.get_avg_fps() + total_num_frames = len(vr) + duration = total_num_frames / video_fps if video_fps else 0 + metadata = VideoMetadata( + total_num_frames=int(total_num_frames), + fps=float(video_fps), + duration=float(duration), + video_backend="decord", + ) + + indices = sample_indices_fn(metadata=metadata, **kwargs) + video = vr.get_batch(indices).asnumpy() + + metadata.update( + { + "frames_indices": indices, + "height": video.shape[1], + "width": video.shape[2], + } + ) + return video, metadata + + +def read_video_pyav( + video_path: Union["URL", "Path"], + sample_indices_fn: Callable, + **kwargs, +): + """ + Decode the video with PyAV decoder. + + Args: + video_path (`str`): + Path to the video file. + sample_indices_fn (`Callable`, *optional*): + A callable function that will return indices at which the video should be sampled. If the video has to be loaded using + by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`. + If not provided, simple uniform sampling with fps is performed. + Example: + def sample_indices_fn(metadata, **kwargs): + return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int) + + Returns: + tuple[`np.array`, `VideoMetadata`]: A tuple containing: + - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]). + - `VideoMetadata` object. + """ + # Lazy import av + requires_backends(read_video_pyav, ["av"]) + import av + + container = av.open(video_path) + total_num_frames = container.streams.video[0].frames + video_fps = container.streams.video[0].average_rate # should we better use `av_guess_frame_rate`? + duration = total_num_frames / video_fps if video_fps else 0 + metadata = VideoMetadata( + total_num_frames=int(total_num_frames), + fps=float(video_fps), + duration=float(duration), + video_backend="pyav", + height=container.streams.video[0].height, + width=container.streams.video[0].width, + ) + indices = sample_indices_fn(metadata=metadata, **kwargs) + + frames = [] + container.seek(0) + end_index = indices[-1] + for i, frame in enumerate(container.decode(video=0)): + if i > end_index: + break + if i >= 0 and i in indices: + frames.append(frame) + + video = np.stack([x.to_ndarray(format="rgb24") for x in frames]) + metadata.frames_indices = indices + return video, metadata + + +def read_video_torchvision( + video_path: Union["URL", "Path"], + sample_indices_fn: Callable, + **kwargs, +): + """ + Decode the video with torchvision decoder. + + Args: + video_path (`str`): + Path to the video file. + sample_indices_fn (`Callable`, *optional*): + A callable function that will return indices at which the video should be sampled. If the video has to be loaded using + by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`. + If not provided, simple uniform sampling with fps is performed. + Example: + def sample_indices_fn(metadata, **kwargs): + return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int) + + Returns: + tuple[`torch.Tensor`, `VideoMetadata`]: A tuple containing: + - Torch tensor of frames in RGB (shape: [num_frames, height, width, 3]). + - `VideoMetadata` object. + """ + warnings.warn( + "Using `torchvision` for video decoding is deprecated and will be removed in future versions. " + "Please use `torchcodec` instead." + ) + video, _, info = torchvision_io.read_video( + video_path, + start_pts=0.0, + end_pts=None, + pts_unit="sec", + output_format="TCHW", + ) + video_fps = info["video_fps"] + total_num_frames = video.size(0) + duration = total_num_frames / video_fps if video_fps else 0 + metadata = VideoMetadata( + total_num_frames=int(total_num_frames), + fps=float(video_fps), + duration=float(duration), + video_backend="torchvision", + ) + + indices = sample_indices_fn(metadata=metadata, **kwargs) + + video = video[indices].contiguous() + metadata.update( + { + "frames_indices": indices, + "height": video.shape[2], + "width": video.shape[3], + } + ) + return video, metadata + + +def read_video_torchcodec( + video_path: Union["URL", "Path"], + sample_indices_fn: Callable, + **kwargs, +): + """ + Decode the video with torchcodec decoder. + + Args: + video_path (`str`): + Path to the video file. + sample_indices_fn (`Callable`): + A callable function that will return indices at which the video should be sampled. If the video has to be loaded using + by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`. + If not provided, simple uniform sampling with fps is performed. + Example: + def sample_indices_fn(metadata, **kwargs): + return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int) + + Returns: + Tuple[`torch.Tensor`, `VideoMetadata`]: A tuple containing: + - Torch tensor of frames in RGB (shape: [num_frames, height, width, 3]). + - `VideoMetadata` object. + """ + # Lazy import torchcodec + requires_backends(read_video_torchcodec, ["torchcodec"]) + from torchcodec.decoders import VideoDecoder + + decoder = VideoDecoder( + video_path, + # Interestingly `exact` mode takes less than approximate when we load the whole video + seek_mode="exact", + # Allow FFmpeg decide on the number of threads for efficiency + num_ffmpeg_threads=0, + device=kwargs.get("device"), + ) + metadata = VideoMetadata( + total_num_frames=decoder.metadata.num_frames, + fps=decoder.metadata.average_fps, + duration=decoder.metadata.duration_seconds, + video_backend="torchcodec", + height=decoder.metadata.height, + width=decoder.metadata.width, + ) + indices = sample_indices_fn(metadata=metadata, **kwargs) + + video = decoder.get_frames_at(indices=indices).data.contiguous() + metadata.frames_indices = indices + return video, metadata + + +VIDEO_DECODERS = { + "decord": read_video_decord, + "opencv": read_video_opencv, + "pyav": read_video_pyav, + "torchvision": read_video_torchvision, + "torchcodec": read_video_torchcodec, +} + + +def load_video( + video: VideoInput, + num_frames: Optional[int] = None, + fps: Optional[Union[int, float]] = None, + backend: str = "pyav", + sample_indices_fn: Optional[Callable] = None, + **kwargs, +) -> np.ndarray: + """ + Loads `video` to a numpy array. + + Args: + video (`VideoInput`): + The video to convert to the numpy array format. Can be a link to video or local path. + num_frames (`int`, *optional*): + Number of frames to sample uniformly. If not passed, the whole video is loaded. + fps (`int` or `float`, *optional*): + Number of frames to sample per second. Should be passed only when `num_frames=None`. + If not specified and `num_frames==None`, all frames are sampled. + backend (`str`, *optional*, defaults to `"pyav"`): + The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision", "torchcodec"]. Defaults to "pyav". + sample_indices_fn (`Callable`, *optional*): + A callable function that will return indices at which the video should be sampled. If the video has to be loaded using + by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`. + If not provided, simple uniformt sampling with fps is performed, otherwise `sample_indices_fn` has priority over other args. + The function expects at input the all args along with all kwargs passed to `load_video` and should output valid + indices at which the video should be sampled. For example: + + Example: + def sample_indices_fn(metadata, **kwargs): + return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int) + + Returns: + tuple[`np.ndarray`, Dict]: A tuple containing: + - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]). + - Metadata dictionary. + """ + + # If `sample_indices_fn` is given, we can accept any args as those might be needed by custom `sample_indices_fn` + if fps is not None and num_frames is not None and sample_indices_fn is None: + raise ValueError( + "`num_frames`, `fps`, and `sample_indices_fn` are mutually exclusive arguments, please use only one!" + ) + + # If user didn't pass a sampling function, create one on the fly with default logic + if sample_indices_fn is None: + + def sample_indices_fn_func(metadata, **fn_kwargs): + return default_sample_indices_fn(metadata, num_frames=num_frames, fps=fps, **fn_kwargs) + + sample_indices_fn = sample_indices_fn_func + + # Early exit if provided an array or `PIL` frames + if not isinstance(video, str): + metadata = [None] * len(video) + return video, metadata + + if urlparse(video).netloc in ["www.youtube.com", "youtube.com"]: + if not is_yt_dlp_available(): + raise ImportError("To load a video from YouTube url you have to install `yt_dlp` first.") + # Lazy import from yt_dlp + requires_backends(load_video, ["yt_dlp"]) + from yt_dlp import YoutubeDL + + buffer = BytesIO() + with redirect_stdout(buffer), YoutubeDL() as f: + f.download([video]) + bytes_obj = buffer.getvalue() + file_obj = BytesIO(bytes_obj) + elif video.startswith("http://") or video.startswith("https://"): + file_obj = BytesIO(requests.get(video).content) + elif os.path.isfile(video): + file_obj = video + else: + raise TypeError("Incorrect format used for video. Should be an url linking to an video or a local path.") + + # can also load with decord, but not cv2/torchvision + # both will fail in case of url links + video_is_url = video.startswith("http://") or video.startswith("https://") + if video_is_url and backend == "opencv": + raise ValueError("If you are trying to load a video from URL, you cannot use 'opencv' as backend") + + if ( + (not is_decord_available() and backend == "decord") + or (not is_av_available() and backend == "pyav") + or (not is_cv2_available() and backend == "opencv") + or (not is_torchvision_available() and backend == "torchvision") + or (not is_torchcodec_available() and backend == "torchcodec") + ): + raise ImportError( + f"You chose backend={backend} for loading the video but the required library is not found in your environment " + f"Make sure to install {backend} before loading the video." + ) + + video_decoder = VIDEO_DECODERS[backend] + video, metadata = video_decoder(file_obj, sample_indices_fn, **kwargs) + return video, metadata + + +def convert_to_rgb( + video: np.ndarray, + input_data_format: Optional[Union[str, ChannelDimension]] = None, +) -> np.ndarray: + """ + Convert video to RGB by blending the transparency layer if it's in RGBA format, otherwise simply returns it. + + Args: + video (`np.ndarray`): + The video to convert. + input_data_format (`ChannelDimension`, *optional*): + The channel dimension format of the input video. If unset, will use the inferred format from the input. + """ + if not isinstance(video, np.ndarray): + raise TypeError(f"Video has to be a numpy array to convert to RGB format, but found {type(video)}") + + # np.array usually comes with ChannelDimension.LAST so let's convert it + if input_data_format is None: + input_data_format = infer_channel_dimension_format(video) + video = to_channel_dimension_format(video, ChannelDimension.FIRST, input_channel_dim=input_data_format) + + # 3 channels for RGB already + if video.shape[-3] == 3: + return video + + # Grayscale video so we repeat it 3 times for each channel + if video.shape[-3] == 1: + return video.repeat(3, -3) + + if not (video[..., 3, :, :] < 255).any(): + return video + + # There is a transparency layer, blend it with a white background. + # Calculate the alpha proportion for blending. + alpha = video[..., 3, :, :] / 255.0 + video = (1 - alpha[..., None, :, :]) * 255 + alpha[..., None, :, :] * video[..., 3, :, :] + return video + + +def pad( + video: np.ndarray, + padding: Union[int, tuple[int, int], Iterable[tuple[int, int]]], + mode: PaddingMode = PaddingMode.CONSTANT, + constant_values: Union[float, Iterable[float]] = 0.0, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, +) -> np.ndarray: + """ + Pads the `video` with the specified (height, width) `padding` and `mode`. + + Args: + video (`np.ndarray`): + The video to pad. + padding (`int` or `tuple[int, int]` or `Iterable[tuple[int, int]]`): + Padding to apply to the edges of the height, width axes. Can be one of three formats: + - `((before_height, after_height), (before_width, after_width))` unique pad widths for each axis. + - `((before, after),)` yields same before and after pad for height and width. + - `(pad,)` or int is a shortcut for before = after = pad width for all axes. + mode (`PaddingMode`): + The padding mode to use. Can be one of: + - `"constant"`: pads with a constant value. + - `"reflect"`: pads with the reflection of the vector mirrored on the first and last values of the + vector along each axis. + - `"replicate"`: pads with the replication of the last value on the edge of the array along each axis. + - `"symmetric"`: pads with the reflection of the vector mirrored along the edge of the array. + constant_values (`float` or `Iterable[float]`, *optional*): + The value to use for the padding if `mode` is `"constant"`. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output video. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: video in (num_frames, num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: video in (num_frames, height, width, num_channels) format. + If unset, will use same as the input video. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the input video. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: video in (num_frames, num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: video in (num_frames, height, width, num_channels) format. + If unset, will use the inferred format of the input video. + + Returns: + `np.ndarray`: The padded video. + + """ + if input_data_format is None: + input_data_format = infer_channel_dimension_format(video) + + def _expand_for_data_format(values): + """ + Convert values to be in the format expected by np.pad based on the data format. + """ + if isinstance(values, (int, float)): + values = ((values, values), (values, values)) + elif isinstance(values, tuple) and len(values) == 1: + values = ((values[0], values[0]), (values[0], values[0])) + elif isinstance(values, tuple) and len(values) == 2 and isinstance(values[0], int): + values = (values, values) + elif isinstance(values, tuple) and len(values) == 2 and isinstance(values[0], tuple): + pass + else: + raise ValueError(f"Unsupported format: {values}") + + # add 0 for channel dimension + values = ( + ((0, 0), (0, 0), *values) if input_data_format == ChannelDimension.FIRST else ((0, 0), *values, (0, 0)) + ) + + # Add additional padding if there's a batch dimension + values = (0, *values) if video.ndim == 5 else values + return values + + padding_map = { + PaddingMode.CONSTANT: "constant", + PaddingMode.REFLECT: "reflect", + PaddingMode.REPLICATE: "replicate", + PaddingMode.SYMMETRIC: "symmetric", + } + padding = _expand_for_data_format(padding) + + pad_kwargs = {} + if mode not in padding_map: + raise ValueError(f"Invalid padding mode: {mode}") + elif mode == PaddingMode.CONSTANT: + pad_kwargs["constant_values"] = _expand_for_data_format(constant_values) + + video = np.pad(video, padding, mode=padding_map[mode], **pad_kwargs) + video = to_channel_dimension_format(video, data_format, input_data_format) if data_format is not None else video + return video + + +def group_videos_by_shape( + videos: list["torch.Tensor"], +) -> tuple[dict[tuple[int, int], "torch.Tensor"], dict[int, tuple[tuple[int, int], int]]]: + """ + Groups videos by shape. + Returns a dictionary with the shape as key and a list of videos with that shape as value, + and a dictionary with the index of the video in the original list as key and the shape and index in the grouped list as value. + """ + grouped_videos = {} + grouped_videos_index = {} + for i, video in enumerate(videos): + shape = video.shape[-2::] + num_frames = video.shape[-4] # video format BTCHW + shape = (num_frames, *shape) + if shape not in grouped_videos: + grouped_videos[shape] = [] + grouped_videos[shape].append(video) + grouped_videos_index[i] = (shape, len(grouped_videos[shape]) - 1) + # stack videos with the same size and number of frames + grouped_videos = {shape: torch.stack(videos, dim=0) for shape, videos in grouped_videos.items()} + return grouped_videos, grouped_videos_index + + +def reorder_videos( + processed_videos: dict[tuple[int, int], "torch.Tensor"], + grouped_videos_index: dict[int, tuple[tuple[int, int], int]], +) -> list["torch.Tensor"]: + """ + Reconstructs a list of videos in the original order. + """ + return [ + processed_videos[grouped_videos_index[i][0]][grouped_videos_index[i][1]] + for i in range(len(grouped_videos_index)) + ]